Как получить историю обратных вызовов в Керасе? - PullRequest
0 голосов
/ 10 февраля 2019

Как я могу получить историю метрик обратного вызова?У меня есть класс Metrics, и я использую его в функции fit модели Keras следующим образом callbacks=[model_metrics].

Это полный код функции класса Metrics и fit.

class Metrics(Callback):

    def on_train_begin(self, logs={}):
        self.val_f1s = []
        self.val_bal_accs = []

    def on_epoch_end(self, epoch, logs={}):
        val_predict = np.argmax((np.asarray(self.model.predict(self.validation_data[0]))).round(), axis=1)
        val_targ = np.argmax(self.validation_data[1], axis=1)
        _val_f1 = metrics.f1_score(val_targ, val_predict, average='weighted')
        _val_bal_acc = metrics.balanced_accuracy_score(val_targ, val_predict)    
        self.val_f1s.append(_val_f1)
        self.val_bal_accs.append(_val_bal_acc)
        print(" — val_f1: {:f} — val_bal_acc: {:f}".format(_val_f1, _val_bal_acc))
        return

model_metrics = Metrics()

history = model.fit(np.array(X_train), y_train, 
                    validation_data=(np.array(X_test), y_test),
                    epochs=5,
                    batch_size=2,
                    callbacks=[model_metrics],
                    shuffle=False,
                    verbose=1)

Как я могу получить history из val_f1 и val_bal_acc?Теперь я могу получить доступ только к loss, val_loss, acc, val_acc:

print(history.history.keys())

1 Ответ

0 голосов
/ 10 февраля 2019

Чтобы взаимодействовать с keras историческим API, вам нужно передать аргументы для metrics, а не callbacks.

В текущем состоянии ваши val_f1 и val_bal_acc не будут храниться в объекте истории, а скорее будут храниться в вашем model_metrics объекте.

Вы можете получить доступони выглядят так:

model_metrics.val_f1s

Это то же самое, что и доступ к атрибуту для любого объекта.

Наконец, если вы хотите создать собственную метрику и хотите получить доступ к ней из истории, вам нужно определитьпользовательскую метрику (как функцию) и затем передайте ее в metrics kwarg в model.compile.Это делается так:

def my_metric(y_true y_pred):
    return y_true # just a dummy return value

# assume that the model is defined somewhere
model.compile(loss=..., optimizer=..., metrics = [my_metric]

И тогда вы сможете найти val_my_metric в историческом объекте, который вам не по силам.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...