Вывод формы подклассовой модели не c как в функциональном API. Поэтому я добавил вызов модели в подкласс модели и определил функциональную модель, как показано ниже. Обратите внимание, что есть несколько способов сделать, и я показываю только один из них. Пожалуйста, проверьте более подробную информацию в аналогичном вопросе, на который я ответил здесь
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32,3,activation = 'relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation = 'relu')
self.d2 = Dense(10, activation = 'relu')
def trythis(self,x):
a = BatchNormalization()
b = a(x)
return b
def call(self, x):
x = self.conv1(x)
x = MyModel.trythis(self,x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
def model(self):
x = tf.keras.layers.Input(shape=(32, 32, 3))
return Model(inputs=[x], outputs=self.call(x))
model = MyModel()
model_functional = model.model()
#model.build((None, 32,32,3))
model_functional.summary()
Резюме выглядит следующим образом
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 30, 30, 32) 896
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32) 128
_________________________________________________________________
flatten_4 (Flatten) (None, 28800) 0
_________________________________________________________________
dense_8 (Dense) (None, 128) 3686528
_________________________________________________________________
dense_9 (Dense) (None, 10) 1290
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________