Знает ли TF train_on_batch () о tf.distribute.Strategy ()? - PullRequest
0 голосов
/ 25 февраля 2020

Я пытаюсь распределить обучение для общей состязательной сети (GAN) на двух графических процессорах с новой настройкой tf.distribute.Strategy. Я знаю, что метод Model.fit () знает о стратегии распределения и может быть вызван следующим образом:

    with strategy.scope():
        dataset = ...
        model = ...
        optimizer = ...
        model.compile(loss = "...", optimizer=optimizer)
        model.fit(dataset, epochs=10)

С другой стороны, пользовательские циклы обучения (то есть ручное применение градиентов) требуют дополнительных шагов например, ручное распределение вашего набора данных в области действия с использованием strategy.experimental_distribute_dataset(dataset), указание reduction=tf.keras.reduction.None в определении функции потерь и вызов каждого этапа обучения с использованием strategy.experimental_run_v2(train_step,inputs).

Мне интересно, что мне следует делать в промежуточный случай; мой тренировочный шаг просто вызывает train_on_batch () следующим образом:

    def train_step(self,g_input,g_output,gd_output,d_input,d_output):
        # Train discriminator
        d_loss = tf.constant(self.discriminator.train_on_batch(d_input, d_output))

        # Train cGAN
        total_loss, g_loss, gd_loss = tf.constant(self.cgan.train_on_batch(g_input, [g_output, gd_output]))

        return d_loss, total_loss, g_loss, gd_loss

, в отличие от ручного применения градиентов. Модели (self.discriminator и self.cgan) были скомпилированы следующим образом:

    self.cgan.compile(loss=['mae', 'binary_crossentropy'],
                      loss_weights=[1, 1],
                      optimizer=optimizer)

    self.discriminator.compile(loss='binary_crossentropy',
                               optimizer=optimizer)

Могу ли я просто вызвать train_on_batch () в рамках стратегии распространения (как в случае Model.fit () ), или мне нужно вручную распределять каждый шаг обучения, используя strategy.experimental_run_v2(), и усреднять результаты, используя метод strategy.reduce()?

...