Я пишу этот пост после прочтения похожих вопросов и ответов, которые не сработали в моем случае. Вы можете заметить, что я определил форму ввода в первом слое.
Я создал очень маленький CNN в Керасе, как показано ниже:
import tensorflow as tf
class MyNet(tf.keras.Model):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224))
self.bn1 = tf.keras.layers.BatchNormalization(axis = 1)
self.fc1 = tf.keras.layers.Dense(10)
self.globalavg = tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first')
def call(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = tf.keras.activations.relu(x)
x = self.globalavg(x)
return self.fc1(x)
Затем я что-то ввел в него и распечатал результат успешно (весы, вероятно, случайны в данный момент, но это нормально):
image = tf.ones(shape = (1, 3, 224, 224)) # Defined "channels first" when created the layers
mynet = MyNet()
outputs = mynet(image)
print(tf.keras.backend.eval(outputs))
Результат, который я увидел на этом шаге, был 10 выходами слоя fc1
:
[[-1.1747773 -0.21640654 -0.16266493 -0.44879064 -0.642066 0.78132695 -0.03920581 -0.30874395 -0.04169023 -0.10409291]]
Затем я попытался сохранить модель с ее весами, вызвав mynet.save('mynet.hdf5')
, и получил следующую ошибку:
NotImplementedError: Currently `save` requires model to be a graph network. Consider using `save_weights`, in order to save the weights of the model.
Обратите внимание, что я новичок в Keras и что большая часть моего опыта связана с PyTorch.
Что я делаю не так?
Обновление:
После ответа @ ikibir я переопределил сеть как последовательную сеть:
myNetAsSeq = tf.keras.models.Sequential()
myNetAsSeq.add(tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224)))
myNetAsSeq.add(tf.keras.layers.BatchNormalization(axis = 1))
myNetAsSeq.add(tf.keras.layers.Activation('relu'))
myNetAsSeq.add(tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first'))
myNetAsSeq.add(tf.keras.layers.Dense(10))
На этот раз звонить myNetAsSeq.save('mynet.hdf5')
удалось.