Keras.model.summary неправильно отображает мою модель ..? - PullRequest
3 голосов
/ 05 мая 2020

Я хочу просмотреть сводку моей модели через keras.model.summary, но это не работает. Мой код выглядит следующим образом:

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32,3,activation = 'relu')
        self.flatten = Faltten()
        self.d1 = Dense(128, activation = 'relu')
        self.d2 = Dense(10, activation = 'relu')

    def trythis(self,x):
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((None, 32,32,3))
model.summary()

Я ожидал, что уровень BatchNorm, но краткое изложение выглядит следующим образом:

Model: "my_model_30"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_31 (Conv2D)           multiple                  896       
_________________________________________________________________
flatten_30 (Flatten)         multiple                  0         
_________________________________________________________________
dense_60 (Dense)             multiple                  3686528   
_________________________________________________________________
dense_61 (Dense)             multiple                  1290      
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0

Он не содержит уровень BatchNorm в методе 'trythis'.

Как решить эту проблему?

Спасибо за внимание.

1 Ответ

1 голос
/ 05 мая 2020

Вывод формы подклассовой модели не c как в функциональном API. Поэтому я добавил вызов модели в подкласс модели и определил функциональную модель, как показано ниже. Обратите внимание, что есть несколько способов сделать, и я показываю только один из них. Пожалуйста, проверьте более подробную информацию в аналогичном вопросе, на который я ответил здесь

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32,3,activation = 'relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation = 'relu')
        self.d2 = Dense(10, activation = 'relu')

    def trythis(self,x):
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = MyModel.trythis(self,x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)
    def model(self):
        x = tf.keras.layers.Input(shape=(32, 32, 3))
        return Model(inputs=[x], outputs=self.call(x))

model = MyModel()
model_functional = model.model()
#model.build((None, 32,32,3))
model_functional.summary()

Резюме выглядит следующим образом

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32)        128       
_________________________________________________________________
flatten_4 (Flatten)          (None, 28800)             0         
_________________________________________________________________
dense_8 (Dense)              (None, 128)               3686528   
_________________________________________________________________
dense_9 (Dense)              (None, 10)                1290      
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...