Python Keras: невозможно загрузить модель с пользовательским слоем, хотя он имеет get_config - PullRequest
0 голосов
/ 26 марта 2020

Я использовал пользовательский слой для моей модели Keras, а именно слой DepthwiseConv3D . Я обучил модель и сохранил ее, используя model.save("model.h5")

from DepthwiseConv3D import DepthwiseConv3D

model = load_model('model.h5',
          custom_objects={'DepthwiseConv3D': DepthwiseConv3D})

Но я получаю «Ошибка типа: неупорядоченные типы: NoneType ()> int ()», вызванная DepthWiseConv3D в:

if (self.groups > self.input_dim):
       raise ValueError('The number of groups cannot exceed the number of channels')

Конфигурация слоев:

 def get_config(self):
        config = super(DepthwiseConv3D, self).get_config()
        config.pop('filters')
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config

Я создал свой слой как

x = DepthwiseConv3D(kernel_size=(7,7,7),
                depth_multiplier=1,groups=9, 
                padding ="same", use_bias=False,
                input_shape=(50, 37, 25, 9))(x)
x = DepthwiseConv3D(depth_multiplier= 32, groups=8, kernel_size=(7,7,7), 
                strides=(2,2,2), activation='relu', padding = "same")(x)
x = DepthwiseConv3D(depth_multiplier= 64, groups=8, kernel_size=(7,7,7), 
                strides=(2,2,2), activation='relu', padding = "same")(x)

Как я могу загрузить свою модель?

1 Ответ

1 голос
/ 28 марта 2020

Метод get_config в используемом вами пользовательском слое неправильно реализован, он не сохраняет все необходимые параметры, поэтому при загрузке модели возникают ошибки.

Если вы можете создать экземпляр модель, используя тот же оригинальный код, вы можете загрузить веса из того же файла, используя model.load_weights. Это просто обходной путь к проблеме, и он должен работать. Правильным решением было бы внедрить правильную версию get_config, что потребовало бы переобучения модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...