TensorFlow Keras: tf.keras.Model train_on_batch против make_train_function - Почему один медленнее, чем другой? - PullRequest
1 голос
/ 15 апреля 2020

В кандидате на выпуск TF2.2 другой способ обучения заключается в создании функции обучения с помощью training_function = tf.keras.Model.make_train_function(), которая при вызове выполняет один шаг обучения.

training_function(data)

Другим способом обучения является использование tf.keras.Model.train_on_batch(data). Тем не менее, я обнаружил, что разница в производительности составляет около 25% времени.

Мне было интересно, есть ли какая-либо причина, по которой использование метода training_function=tf.keras.Model.make_train_function();training_function(data) для обучения будет быстрее, чем tf.keras.Model.train_on_batch()?

(Некоторые другие детали: я установил TF с помощью conda, поэтому я на самом деле использую TF2.1 и реализовал TF2.2. make_train_function(), используя «_make_train_function()» (обратите внимание на подчеркивание) в TF2.1:

from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras import backend as K
model = ... # Some Keras model
data = ... # Some TF dataset
_,_,_ = model._standardize_user_data(data, None)
training_function = model._make_train_function()

data,data_y_None,data_sampleweights_None = model._standardize_user_data(data, None)
data = training_utils.ModelInputs(data).as_list()
data = data + list(data_y_None or []) + list(data_sampleweights_None or [])
if not isinstance(K.symbolic_learning_phase(), int):
    data += [True]
training_function(data)

Я не совсем уверен, почему этот метод обучения выполняет намного быстрее, но тренировочную работу. Любая помощь будет оценена :) Я надеюсь, что есть что-то действительно очевидное, что я пропал)

ОБНОВЛЕНИЕ: быстрая проверка кода для TF2.1 и TF2.2-r c показывает, что train_on_batch звонит (_)make_train_function каждый раз, когда train_on_batch вызывается, какие счета за дополнительные 25% времени. Теперь возникает вопрос, почему train_on_batch каждый раз воссоздает функцию обучения?

ОБНОВЛЕНИЕ: Функция обучения создается только один раз, поскольку функция обучения затем сохраняется как свойство объекта. Однако обучающая функция создается заново (и свойство объекта перезаписывается), если обнаруживает, что модель перекомпилирована с момента последнего вызова. По какой-то причине эта перекомпиляция запускается в моем коде, вызывая пересоздание обучающей функции на каждом шаге, и я не знаю почему. Без полной диагностики c пример, вам трудно помочь, но мне интересно, связано ли это со мной, используя функции .add_metric() и .add_loss() для пользовательских tf.keras.Model ().

1 Ответ

1 голос
/ 15 апреля 2020

train_on_batch звонит _make_execution_function() каждый раз в tensorflow1.15. Но функция выполнения создается только один раз (если она еще не создана или когда модель перекомпилируется) из-за оператора this if

if getattr(self, 'train_function', None) is None or has_recompiled:

В любом случае, _make_train/test/predict_function() частные функции и их цель - помочь разработчикам во внутренней реализации, и, как вы правильно заметили, их нет в tensorflow2.2.

В tensorflow2.2 у вас есть make_execution_function(), который полностью выполняется только один раз, потому что это if заявление

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...