Лучшие практики в Tensorflow 2.0 (шаг обучения) - PullRequest
0 голосов
/ 27 сентября 2019

В тензорном потоке 2.0 вам не нужно беспокоиться о фазе обучения (размер партии, количество эпох и т. Д.), Потому что все можно определить в методе compile: model.fit(X_train,Y_train,batch_size = 64,epochs = 100).

Но у меня естьвидел следующий стиль кода:

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)

Таким образом, здесь вы можете наблюдать «более подробный» код, где вы вручную определяете процедуру обучения для циклов.

У меня следующий вопрос: что такоелучшая практика в Tensorflow 2.0?Я не нашел полного учебника.

1 Ответ

2 голосов
/ 27 сентября 2019

Используйте то, что лучше для ваших нужд.

Оба метода описаны в руководствах по Tensorflow.

Если вам не нужно ничего особенного, без лишних потерь, странных показателей или сложных вычислений градиента, просто используйте model.fit() или model.fit_generator().Это совершенно нормально и делает вашу жизнь проще.

Пользовательский цикл обучения может пригодиться, если у вас есть сложные модели с нетривиальным расчетом потерь / градиентов.

До сих пор два приложения, которые я пробовал, были проще с этим:

  • Одновременная подготовка генератора и дискриминатора GAN без необходимости повторения шага генерации дважды.(Это сложно, потому что у вас есть функция потерь, которая применяется к различным значениям y_true, и каждый случай должен обновлять только часть модели) - Другой вариант потребовал бы иметь несколько отдельных моделей, каждая модель со своим собственным trainable=True/False Конфигурация, а затем обучить в отдельных фазах
  • Обучающие входные данные (хорошо для моделей передачи стилей) - в качестве альтернативы, создайте пользовательский слой, который принимает фиктивные входные данные и который выводит свои собственные обучаемые веса.Но становится сложным скомпилировать несколько функций потерь для каждого из выходов базовой сети и сети стилей.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...