Использование td.cond во время обучения приводит к снижению пропускной способности - PullRequest
0 голосов
/ 14 мая 2019

В процессе использования resnet50 для обучения imagenet мы использовали LARS для обновления скорости обучения и вычисления LR на каждом этапе обучения.Пропускная способность обучения составляет около 5500. Для этого мы намерены оптимизировать и рассчитывать операцию LR каждые несколько шагов для повышения пропускной способности.В исходном коде мы выполняем compute_lr вычисление на каждом шаге.

Я изменил код, как показано ниже:

  • Global_step - это тензор, используемый для наблюдения за тем, какой шаг обучения;
  • 2 - это константа, указывающая, что lr вычисляется каждые 2 шага.

Код:

def compute_lr()
    coumpte_lr 
       ...
    stored_lr
       ...
    return lr
def get_larsvalue()
    get_stored_lr
       ...
    return lr

tf.cond(tf.cast(tf.equal(tf.mod(gg,2),0),tf.bool),lambda:self.compute_lr(),lambda: self.get_larsvalue())

Но после того, как я изменилкод, пропускная способность упала.После анализа, я думаю, это потому, что tf.cond не ленивая операция, она выполнит обе ветви, что, очевидно, не то, что я хочу.Я не знаю, как писать код, чтобы завершить мои мысли сейчас, прошу всех помочь.

...