Модель Keras имеет разные потери - PullRequest
0 голосов
/ 13 апреля 2020

Я сейчас работаю над моделью Keras. Моя цель - иметь как можно более низкую среднеквадратическую среднюю величину, поэтому я устанавливаю потери и метрики в среднеквадратичном исчислении. После обучения модели, когда я вычисляю среднеквадратичное отклонение модели по данным обучения, я получаю более высокий результат. Почему это так?

Моя функция RMSE для этой модели:

from keras import backend as K
def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true)))

После обучения модели я вычисляю свое значение rmse с помощью этой функции

from sklearn.metrics import mean_squared_error
from math import sqrt
def measure_mse(actual, predicted):
    return  sqrt(mean_squared_error(actual, predicted))

measure_mse(train_y, model.predict(train_x))

================================================== ==================

layer_in = Input(shape=(16,1))

layer_regr = GRU(64, activation='relu', kernel_regularizer=regularizers.l2(0.01), kernel_initializer='truncated_normal', return_sequences=True)(layer_in)

layer_regr = Dropout(0.2)(layer_regr)

layer_regr = GRU(16, activation='relu', kernel_initializer='truncated_normal', return_sequences=True)(layer_regr)

layer_regr = GRU(64, activation='relu', kernel_initializer='truncated_normal')(layer_regr)

layer_regr = Dropout(0.2)(layer_regr)

layer_regr = Dense(16, activation='relu', kernel_initializer='truncated_normal')(layer_regr)

layer_out = Dense(1,)(layer_regr)

model = Model(inputs=layer_in, outputs=layer_out)

model.compile(loss=root_mean_squared_error, optimizer='adam', metrics=[root_mean_squared_error])

checkpointer = ModelCheckpoint(filepath="weights.hdf5", verbose=1, save_best_only=True)

model.fit(train_x, train_y, epochs=10, batch_size = 16, validation_data=(test_x, test_y), verbose=1, callbacks=[checkpointer])

model.load_weights('weights.hdf5')

======================= ===================================================

результаты:

Модель keras: val_root_mean_squared_error: 1.8079

Я вычисляю: 2.1155

...