Использование tf.cond () в функции модели оценки для обучения WGAN на TPU вызывает удвоение global_step - PullRequest
0 голосов
/ 27 января 2019

Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и сопутствующей функцией модели, чтобы попытаться реализовать обучающий цикл WGAN.Я пытаюсь использовать tf.cond для объединения двух тренировочных операций для TPUEstimatorSpec следующим образом:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_opt и critic_opt - это функция минимизации используемого оптимизатора, setобновить глобальный шаг, а также.CRITIC_UPDATES_PER_GEN_UPDATE является константой Python именно для этого и является частью обучения WGAN.Я пытался найти модель GAN, используя tf.cond, но все модели используют tf.group, что я не могу использовать, потому что вам нужно оптимизировать критику намного больше, чем генератор.Однако каждый раз, когда я запускаю 100 партий, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки.Моя модель все еще тренируется правильно, или tf.cond просто не должен использоваться для обучения GAN?

1 Ответ

0 голосов
/ 27 января 2019

tf.cond не должен использоваться таким образом для обучения GAN.

Вы получаете 200, потому что на каждом этапе обучения оцениваются побочные эффекты (например, операции назначения) и , true_fn и false_fn. Одним из побочных эффектов является глобальная операция шага tf.assign_add, которую определяют оба оптимизатора.

Следовательно, то, что происходит, похоже на

  • Исполнение global_step++ (gen_opt) и global_step++ (critic_op)
  • Оценка состояния
  • Исполнение true_fn тела или false_fn тела (в зависимости от состояния).

Если вы хотите обучить GAN с использованием tf.cond, вы должны удалить все побочные операции (например, назначение, отсюда и определение шага оптимизации) из true_fn / false_fn и объявить все внутри них.

В качестве ссылки вы можете увидеть этот ответ о поведении tf.cond: https://stackoverflow.com/a/37064128/2891324

...