Я реализую tf.keras.Model
(не Sequential
модель!), Который должен быть обучен с использованием fit_generator
.Однако fit_generator
вызывает ошибку, возможно, из-за того, что входные фигуры недоступны во время компиляции.
Вот минимальный пример:
import tensorflow as tf
import numpy as np
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)
def call(self, inputs, training=None, mask=None):
return self.dense2(self.dense1(inputs))
class MyGenerator(tf.keras.utils.Sequence):
def __len__(self):
# Number of batches per epoch
return 1
def __getitem__(self, _):
# Generate one batch of data
x = np.array([[1., 2., 3.]])
y = np.array([[0., 1., 0.5]])
return x, y
if __name__ == '__main__':
m = MyModel()
g = MyGenerator()
m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
m.fit_generator(g)
Последняя строка поднимает
AttributeError: 'MyModel' object has no attribute 'total_loss'
Итак, как правильно использовать fit_generator
в пользовательской модели Keras?