подкласс tf.keras.Model не может получить результат sumy () - PullRequest
1 голос
/ 15 апреля 2019

Я хочу построить подкласс tf.keras.Model и хочу увидеть структуру модели с функцией summary.Но это не работает.Ниже мой код:

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(128, activation='relu')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.summary()

Ошибка:

ValueError: Эта модель еще не была построена.Сначала создайте модель, вызвав build() или вызвав fit() с некоторыми данными, или укажите аргумент input_shape в первом слое (ях) для автоматической сборки.

1 Ответ

2 голосов
/ 15 апреля 2019

Вам нужно вызвать каждый слой один раз, чтобы вывести формы, а затем вызвать build() метод tf.keras.Model с входной формой модели в качестве аргумента:

import tensorflow as tf
import numpy as np

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(128, activation='relu')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
        x = np.random.normal(size=(1, 32, 32, 3))
        x = tf.convert_to_tensor(x)
        _ = self.call(x)

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((32, 32, 3))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  896       
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  3686528   
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0
_________________________________________________________________


...