Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и сопутствующей функцией модели, чтобы попытаться реализовать обучающий цикл WGAN.Я пытаюсь использовать tf.cond
для объединения двух тренировочных операций для TPUEstimatorSpec следующим образом:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
и critic_opt
- это функция минимизации используемого оптимизатора, setобновить глобальный шаг, а также.CRITIC_UPDATES_PER_GEN_UPDATE
является константой Python именно для этого и является частью обучения WGAN.Я пытался найти модель GAN, используя tf.cond
, но все модели используют tf.group
, что я не могу использовать, потому что вам нужно оптимизировать критику намного больше, чем генератор.Однако каждый раз, когда я запускаю 100 партий, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки.Моя модель все еще тренируется правильно, или tf.cond
просто не должен использоваться для обучения GAN?