model.summary () не может напечатать выходную форму при использовании модели подкласса - PullRequest
2 голосов
/ 19 марта 2019

Это два метода для создания модели keras, но output shapes суммарных результатов двух методов различны.Очевидно, что первый выводит больше информации и облегчает проверку правильности сети.

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)


def func_api():
    x = Input(shape=(24, 24, 3))
    y = layers.Conv2D(28, 3, strides=1)(x)
    return Model(inputs=[x], outputs=[y])

if __name__ == '__main__':
    func = func_api()
    func.summary()

    sub = subclass()
    sub.build(input_shape=(None, 24, 24, 3))
    sub.summary()

output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 24, 24, 3)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 22, 22, 28)        784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            multiple                  784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________

Итак, как мне использовать метод подкласса, чтобы получитьoutput shape на резюме ()?

1 Ответ

2 голосов
/ 19 марта 2019

Я использовал этот метод для решения этой проблемы, я не знаю, есть ли более простой способ.

class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model():
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))



if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()
...