Я решил перейти с keras на tf.keras (как рекомендовано здесь ). Поэтому я установил tf.__version__=2.0.0
и tf.keras.__version__=2.2.4-tf
. В более старой версии моего кода (с использованием более старой версии Tensorflow tf.__version__=1.x.x
) я использовал обратный вызов для вычисления пользовательских метрик для всех данных проверки в конце каждой эпохи. Идея сделать это была взята из здесь . Однако создается впечатление, что атрибут «validation_data» устарел, поэтому следующий код больше не работает.
class ValMetrics(Callback):
def on_train_begin(self, logs={}):
self.val_all_mse = []
def on_epoch_end(self, epoch, logs):
val_predict = np.asarray(self.model.predict(self.validation_data[0]))
val_targ = self.validation_data[1]
val_epoch_mse = mse_score(val_targ, val_predict)
self.val_epoch_mse.append(val_epoch_mse)
# Add custom metrics to the logs, so that we can use them with
# EarlyStop and csvLogger callbacks
logs["val_epoch_mse"] = val_epoch_mse
print(f"\nEpoch: {epoch + 1}")
print("-----------------")
print("val_mse: {:+.6f}".format(val_epoch_mse))
return
Мой текущий обходной путь заключается в следующем. Я просто дал validation_data в качестве аргумента классу ValMetrics
:
class ValMetrics(Callback):
def __init__(self, validation_data):
super(Callback, self).__init__()
self.X_val, self.y_val = validation_data
Тем не менее у меня есть несколько вопросов: действительно ли атрибут «validation_data» устарел или его можно найти в другом месте? Есть ли лучший способ получить доступ к данным проверки в конце каждой эпохи, чем с помощью вышеупомянутого обходного пути?
Большое спасибо!