Keras версия 2.2.4, tenorflow версия 1.13.1, я использую ноутбуки colab
Я пытаюсь создать собственный инициализатор и сохранить модель с помощью model.save (), но при загрузкемодель снова я получаю следующую ошибку:
TypeError: myInit () отсутствует 1 обязательный позиционный аргумент: 'input_shape'
У меня есть следующий код:
import numpy as np
import tensorflow as tf
import keras
from google.colab import drive
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten, Lambda, Reshape, Activation
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras import backend as K
K.set_image_data_format('channels_first')
K.backend()
# the output should be 'tensorflow'
tenensflow
def myInit( input_shape, dtype=None):
weights = np.full( input_shape, 2019 )
return K.variable( weights, dtype=dtype )
Этот инициализатор получает input_shape и возвращает тензор keras, как в документах: https://keras.io/initializers/
model = Sequential()
model.add(
Dense( 40, input_shape=(784,) )
)
model.add(
Dense( 30, kernel_initializer=myInit )
)
model.add(
Dense( 5 )
)
model.build()
Веса инициализируются правильно, потому что когда я вызываю model.layers[1].get_weights()
, я получаю массив, полный 2019. Я сохраняю модель, используя model.save:
model.save(somepath)
В другой записной книжке я затем вызываю
model = load_model(somepath,
custom_objects={
'tf' : tf,
'myInit' : myInit
}
)
В этом блокноте myInit
и весь импорт также определены.Когда я вызываю load_model
, я получаю следующую ошибку:
TypeError: myInit () отсутствует 1 обязательный позиционный аргумент: 'input_shape'
Так кажется, когда модельзагружен, input_shape не передается myInit.У кого-нибудь есть идеи?
Полный след:
TypeError Traceback (most recent call last)
<ipython-input-25-544d137de03f> in <module>()
2 custom_objects={
3 'tf' : tf,
----> 4 'myInit' : myInit
5 }
6 )
/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py in load_model(filepath, custom_objects, compile)
417 f = h5dict(filepath, 'r')
418 try:
--> 419 model = _deserialize_model(f, custom_objects, compile)
420 finally:
421 if opened_new_file:
/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py in _deserialize_model(f, custom_objects, compile)
223 raise ValueError('No model found in config.')
224 model_config = json.loads(model_config.decode('utf-8'))
--> 225 model = model_from_config(model_config, custom_objects=custom_objects)
226 model_weights_group = f['model_weights']
227
/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py in model_from_config(config, custom_objects)
456 '`Sequential.from_config(config)`?')
457 from ..layers import deserialize
--> 458 return deserialize(config, custom_objects=custom_objects)
459
460
/usr/local/lib/python3.6/dist-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
53 module_objects=globs,
54 custom_objects=custom_objects,
---> 55 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
143 config['config'],
144 custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 145 list(custom_objects.items())))
146 with CustomObjectScope(custom_objects):
147 return cls.from_config(config['config'])
/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py in from_config(cls, config, custom_objects)
298 for conf in layer_configs:
299 layer = layer_module.deserialize(conf,
--> 300 custom_objects=custom_objects)
301 model.add(layer)
302 if not model.inputs and build_input_shape:
/usr/local/lib/python3.6/dist-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
53 module_objects=globs,
54 custom_objects=custom_objects,
---> 55 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
145 list(custom_objects.items())))
146 with CustomObjectScope(custom_objects):
--> 147 return cls.from_config(config['config'])
148 else:
149 # Then `cls` may be a function returning a class.
/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in from_config(cls, config)
1107 A layer instance.
1108 """
-> 1109 return cls(**config)
1110
1111 def count_params(self):
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/layers/core.py in __init__(self, units, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, **kwargs)
846 self.activation = activations.get(activation)
847 self.use_bias = use_bias
--> 848 self.kernel_initializer = initializers.get(kernel_initializer)
849 self.bias_initializer = initializers.get(bias_initializer)
850 self.kernel_regularizer = regularizers.get(kernel_regularizer)
/usr/local/lib/python3.6/dist-packages/keras/initializers.py in get(identifier)
509 elif isinstance(identifier, six.string_types):
510 config = {'class_name': str(identifier), 'config': {}}
--> 511 return deserialize(config)
512 elif callable(identifier):
513 return identifier
/usr/local/lib/python3.6/dist-packages/keras/initializers.py in deserialize(config, custom_objects)
501 module_objects=globals(),
502 custom_objects=custom_objects,
--> 503 printable_module_name='initializer')
504
505
/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
152 custom_objects = custom_objects or {}
153 with CustomObjectScope(custom_objects):
--> 154 return cls(**config['config'])
155 elif isinstance(identifier, six.string_types):
156 function_name = identifier
TypeError: myInit() missing 1 required positional argument: 'input_shape'
Примечание. Я также разместил это на https://github.com/keras-team/keras/issues/12452, но я подумал, что это было бы лучшим местом для этого.