AttributeError: validation_data не может быть доступен в Custom Callback - PullRequest
0 голосов
/ 30 сентября 2019

Я реализовал модель прогнозирования с помощью LSTM и написал собственный обратный вызов, чтобы получить доступ к некоторым дополнительным метрикам на обратном масштабированном входе.

Класс Metrics выглядит следующим образом:

class Metrics(keras.callbacks.Callback):
    def __init__(self, scaler):
        self.scaler = scaler

    def on_train_begin(self, logs):
        self._data = []

    def on_epoch_end(self, batch, logs):
        val_data, val_target = self.validation_data[0], self.validation_data[1]

        # calculating and appending the metric here
        # self._data.append({metric})

        return

    def get_data(self):
        return self._data

Я тогда использую это так:

metrics = Metrics(scaler)

model = Sequential()
model.add(LSTM(32, 
                   return_sequences=True,
                   activation='tanh', 
                   input_shape=(dataset.X_train.shape[1], dataset.X_train.shape[2])))
# more layers and model.compile here

history = model.fit(dataset.X_train, 
                    dataset.y_train,  
                    epochs=EPOCHS,  
                    validation_data=(dataset.X_valid, dataset.y_valid), 
                    callbacks=[metrics])

Есть идеи?

...