Проверка Tf.Keras NN не улучшилась - PullRequest
0 голосов
/ 11 октября 2019

Я столкнулся с проблемой во время обучения и проверки моей простой нейронной сети TensorFlow. Что я делаю не так с разделением проверки?

Код ниже. От эпохи к эпохе потеря моделей и ошибка MAE улучшались, но оценка и ошибка не улучшались вообще. Я не могу жить с этим счетом. Просто идите вперед - при тех же предварительно обработанных данных XGBoost Regressor обеспечивает значительно лучший результат ... Его нет в данных ...

создание модели

def build_model():
    model = keras.Sequential([
        layers.Dense(8, input_shape=[len(X.keys())], activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(8, activation='relu'),
        layers.Dense(1)])
    optimizer = tf.keras.optimizers.RMSprop(0.001)
    model.compile(loss='mse',
                  optimizer=optimizer,
                  metrics=['mae', 'mse'])
    return model

параметров для проверки

k = 4
num_val_samples = len(X) // k
num_epochs = 500
all_scores = []

es = keras.callbacks.EarlyStopping(
    monitor='mean_squared_error',
    patience=10,
    verbose=1,
    restore_best_weights=True
)

расщепление, обучение, оценка, разочарование, LOL. Я действительно знаю, что здесь есть какая-то ошибка

for i in range(k):
    print('processing fold #', i)
    val_data = X[i * num_val_samples: (i + 1) * num_val_samples]
    val_targets = y[i * num_val_samples: (i + 1) * num_val_samples]
    partial_train_data = np.concatenate\
        ([X[:i * num_val_samples],
          X[(i + 1) * num_val_samples:]],
        axis=0)
    partial_train_targets = np.concatenate\
        ([y[:i * num_val_samples],
          y[(i + 1) * num_val_samples:]],
         axis=0)
    model = build_model()
    model.fit(
        partial_train_data,
        partial_train_targets,
        epochs=num_epochs,
        batch_size=16,
        validation_data=(val_data, val_targets,),
        callbacks=[es],
        verbose=1)
    val_mse = model.evaluate(val_data, val_targets, verbose=0)
    all_scores.append(val_mse)

print('scores MSE for each folds: ',all_scores)
print('average MSE total score is ',np.mean(all_scores))

текущие результаты - оценка за проверку ужасна

  • Epoch 1/500 1095/1095 [==============================] - 1s 863us / образец - потеря: 106344.0772 - mean_absolute_error: 213.7770 - mean_squared_error: 106344.0703 - val_loss: 39900570856.8548 - val_meanval_mean_squared_error: 39900573696.0000 Epoch 2/500 1095/1095 [=====================================] - 0s 228us / sample- потеря: 19830.6569
    • mean_absolute_error: 79.8033 - mean_squared_error: 19830.6562 - val_loss: 39893319971.7699 - val_mean_absolute_error: 184646.7344===_80_80_80_095_95_95_95_95_95_95=======================] - 0s 211us / sample - потеря: 6149.0242
    • mean_absolute_error: 38.5918 - mean_squared_error: 6149.0234 - val_loss: 39880093617.4466 - val_mean_absolute_error: 184613.3750 - val_mean_squared_error: 39880093696.0000 Epoch 4/500 1095/1095 [==============================] - 0s 219us / sample - потеря: 1326.3726
    • mean_absolute_error: 18.1357 - mean_squared_error: 1326.3724 - val_loss: 39890125899.7479 - val_mean_absolute_error: 184635.3750 - val_mean_squared_error: 39890124800.0000 Epoch 5/500 1095/1095 [=================================================] - 0 с 219us / образец - потеря: 1422,7218
    • mean_absolute_error: 12,3101 - mean_squared_error: 1422,7222 - val_loss: 39884738823,7151 - val_mean_absolute_error: 184 474 0 0 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 6,80,0 (*) * Все права защищены. / 1095 [=====================================] - 0 с 210 мкс / выборка - потеря: 369.3275 - mean_absolute_error: 9.4986 - mean_squared_error:369.3273 - val_loss: 39881353434.8274 - val_mean_absolute_error: 184615.2812 - val_mean_squared_error: 39881351168.0000 Epoch 7/500 1095/1095 [======================================] - 0s 219us / выборка - потеря: 79,9570 - mean_absolute_error: 6,5568 - mean_squared_error: 79,9570 - val_loss: 39877638890.2575 - val_mean_absolute_error: 184607.7656 - val_mean_squared_error: 39877644288.0000 Epoch 8/500 1095/1095 [===============================] - 0s 210us / образец - потери: 140,7780 - mean_absolute_error: 6,4829 - mean_squared_error: 140,7780 - * * val_loss тысяча тридцать-девять: +39882211653,4356 - val_mean_absolute_error: 184617,6719 - val_mean_squared_error: +39882215424,0000 * * +1040 эпоха 9/500 1095/1095 [==============================] - 0s 225us / sample - потеря: 57.5072 - mean_absolute_error: 5.8725 - mean_squared_error: 57.5072 - val_loss: 39883264151.4959 -val_mean_absolute_error: 184619.9375 - val_mean_squared_error: 39883268096.0000 и т. д. *
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...