Как заставить `fit_generator` работать с` tf.keras.Model` - PullRequest
1 голос
/ 27 мая 2019

Я реализую tf.keras.Model (не Sequential модель!), Который должен быть обучен с использованием fit_generator.Однако fit_generator вызывает ошибку, возможно, из-за того, что входные фигуры недоступны во время компиляции.

Вот минимальный пример:

import tensorflow as tf
import numpy as np


class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs, training=None, mask=None):
        return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

    def __len__(self):
        # Number of batches per epoch
        return 1

    def __getitem__(self, _):
        # Generate one batch of data
        x = np.array([[1., 2., 3.]])
        y = np.array([[0., 1., 0.5]])

        return x, y


if __name__ == '__main__':
    m = MyModel()    
    g = MyGenerator()

    m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    m.fit_generator(g)

Последняя строка поднимает

AttributeError: 'MyModel' object has no attribute 'total_loss'

Итак, как правильно использовать fit_generator в пользовательской модели Keras?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...