Доступ к устаревшему атрибуту «validation_data» в tf.keras.callbacks.Callback - PullRequest
2 голосов
/ 05 февраля 2020

Я решил перейти с keras на tf.keras (как рекомендовано здесь ). Поэтому я установил tf.__version__=2.0.0 и tf.keras.__version__=2.2.4-tf. В более старой версии моего кода (с использованием более старой версии Tensorflow tf.__version__=1.x.x) я использовал обратный вызов для вычисления пользовательских метрик для всех данных проверки в конце каждой эпохи. Идея сделать это была взята из здесь . Однако создается впечатление, что атрибут «validation_data» устарел, поэтому следующий код больше не работает.

class ValMetrics(Callback):

    def on_train_begin(self, logs={}):

        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs):

        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]

        val_epoch_mse = mse_score(val_targ, val_predict)

        self.val_epoch_mse.append(val_epoch_mse)

        # Add custom metrics to the logs, so that we can use them with
        # EarlyStop and csvLogger callbacks
        logs["val_epoch_mse"] = val_epoch_mse

        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))

        return

Мой текущий обходной путь заключается в следующем. Я просто дал validation_data в качестве аргумента классу ValMetrics:

class ValMetrics(Callback):

    def __init__(self, validation_data):
        super(Callback, self).__init__()
        self.X_val, self.y_val = validation_data

Тем не менее у меня есть несколько вопросов: действительно ли атрибут «validation_data» устарел или его можно найти в другом месте? Есть ли лучший способ получить доступ к данным проверки в конце каждой эпохи, чем с помощью вышеупомянутого обходного пути?

Большое спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...