В кандидате на выпуск TF2.2 другой способ обучения заключается в создании функции обучения с помощью training_function = tf.keras.Model.make_train_function()
, которая при вызове выполняет один шаг обучения.
training_function(data)
Другим способом обучения является использование tf.keras.Model.train_on_batch(data)
. Тем не менее, я обнаружил, что разница в производительности составляет около 25% времени.
Мне было интересно, есть ли какая-либо причина, по которой использование метода training_function=tf.keras.Model.make_train_function();training_function(data)
для обучения будет быстрее, чем tf.keras.Model.train_on_batch()
?
(Некоторые другие детали: я установил TF с помощью conda, поэтому я на самом деле использую TF2.1 и реализовал TF2.2. make_train_function()
, используя «_make_train_function()
» (обратите внимание на подчеркивание) в TF2.1:
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras import backend as K
model = ... # Some Keras model
data = ... # Some TF dataset
_,_,_ = model._standardize_user_data(data, None)
training_function = model._make_train_function()
data,data_y_None,data_sampleweights_None = model._standardize_user_data(data, None)
data = training_utils.ModelInputs(data).as_list()
data = data + list(data_y_None or []) + list(data_sampleweights_None or [])
if not isinstance(K.symbolic_learning_phase(), int):
data += [True]
training_function(data)
Я не совсем уверен, почему этот метод обучения выполняет намного быстрее, но тренировочную работу. Любая помощь будет оценена :) Я надеюсь, что есть что-то действительно очевидное, что я пропал)
ОБНОВЛЕНИЕ: быстрая проверка кода для TF2.1 и TF2.2-r c показывает, что train_on_batch
звонит (_)make_train_function
каждый раз, когда train_on_batch
вызывается, какие счета за дополнительные 25% времени. Теперь возникает вопрос, почему train_on_batch
каждый раз воссоздает функцию обучения?
ОБНОВЛЕНИЕ: Функция обучения создается только один раз, поскольку функция обучения затем сохраняется как свойство объекта. Однако обучающая функция создается заново (и свойство объекта перезаписывается), если обнаруживает, что модель перекомпилирована с момента последнего вызова. По какой-то причине эта перекомпиляция запускается в моем коде, вызывая пересоздание обучающей функции на каждом шаге, и я не знаю почему. Без полной диагностики c пример, вам трудно помочь, но мне интересно, связано ли это со мной, используя функции .add_metric()
и .add_loss()
для пользовательских tf.keras.Model ().