Существует очень большая разница между моделью подклассов керас и другими моделями керас (последовательная и функциональная).
Последовательные модели и функциональные модели представляют собой структуры данных, которые представляют DAG слоев. Проще говоря, функциональная или последовательная модель - это статические графики слоев, построенные путем наложения друг на друга, как LE GO. Поэтому, когда вы предоставляете input_shape первому слою, эти (функциональные и последовательные) модели могут вывести форму всех других слоев и построить модель. Затем вы можете распечатать формы ввода / вывода, используя model.summary()
.
С другой стороны, подклассовая модель определяется с помощью тела (метода вызова) кода Python. Для подклассовой модели здесь нет графика слоев. Мы не можем знать, как слои связаны друг с другом (потому что это определено в теле вызова, а не в виде явной структуры данных), поэтому мы не можем вывести формы ввода / вывода. Таким образом, для модели подкласса форма ввода / вывода нам неизвестна, пока она не будет сначала проверена с правильными данными. В методе compile () мы будем выполнять отложенную компиляцию и ждать правильных данных. Чтобы он мог определить форму промежуточных слоев, нам нужно запустить с правильными данными и затем использовать model.summary()
. Без запуска модели с данными она выдаст ошибку, как вы заметили.
Ниже приведен пример с сайта Tensorflow.
class ThreeLayerMLP(keras.Model):
def __init__(self, name=None):
super(ThreeLayerMLP, self).__init__(name=name)
self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')
self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')
self.pred_layer = layers.Dense(10, name='predictions')
def call(self, inputs):
x = self.dense_1(inputs)
x = self.dense_2(x)
return self.pred_layer(x)
def get_model():
return ThreeLayerMLP(name='3_layer_mlp')
model = get_model()
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
batch_size=64,
epochs=1)
model.summary()
Надеюсь, это поможет. Спасибо!