Как пройти обучение и проверить потерю оболочки Keras scikit-learn при перекрестной проверке? - PullRequest
1 голос
/ 26 марта 2020

Я знаю, что model.fit в keras возвращает callback.Исторический объект, где мы можем получить потери и другие метрики из него следующим образом.

...
train_history = model.fit(X_train, Y_train,
                    batch_size=batch_size, nb_epoch=nb_epoch,
                    verbose=1, validation_data=(X_test, Y_test))
loss = train_history.history['loss']
val_loss = train_history.history['val_loss']

Однако в моем новом эксперименте enet Я используя cross validation с моделью keras используя kerasclassifier (полный пример кода: https://chrisalbon.com/deep_learning/keras/k-fold_cross-validating_neural_networks/)

# Wrap Keras model so it can be used by scikit-learn
neural_network = KerasClassifier(build_fn=create_network, 
                                 epochs=10, 
                                 batch_size=100, 
                                 verbose=1)

Поскольку сейчас я использую перекрестную проверку, я не уверен, как добраться на поезде и проверка утраты.

1 Ответ

3 голосов
/ 26 марта 2020

Как явно указано в документации , cross_val_score включает аргумент scoring, который

Похож на cross_validate, но только один показатель c разрешено.

, следовательно, его нельзя использовать для возврата всей информации о потерях и метри c информации о Keras model.fit().

Оболочка Scikit-learn Keras подразумевается как удобство, при условии, что вы на самом деле не интересуетесь всеми базовыми деталями (такими как обучение и проверка, потеря и точность). Если это не так, вам следует вернуться к использованию Keras напрямую. Вот как вы можете это сделать, используя пример, с которым вы связались, и элементы моего ответа :

import numpy as np
from keras import models, layers
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold

np.random.seed(0)

# Number of features
number_of_features = 100

# Generate features matrix and target vector
features, target = make_classification(n_samples = 10000,
                                       n_features = number_of_features,
                                       n_informative = 3,
                                       n_redundant = 0,
                                       n_classes = 2,
                                       weights = [.5, .5],
                                       random_state = 0)

def create_network():
    network = models.Sequential()
    network.add(layers.Dense(units=16, activation='relu', input_shape=(number_of_features,)))
    network.add(layers.Dense(units=16, activation='relu'))
    network.add(layers.Dense(units=1, activation='sigmoid'))

    network.compile(loss='binary_crossentropy', 
                    optimizer='rmsprop', 
                    metrics=['accuracy']) 

    return network

n_splits = 3
kf = KFold(n_splits=n_splits, shuffle=True)

loss = []
acc = []
val_loss = []
val_acc = []

# cross validate:
for train_index, val_index in kf.split(features):
    model = create_network()
    hist = model.fit(features[train_index], target[train_index],
                     epochs=10,
                     batch_size=100,
                     validation_data = (features[val_index], target[val_index]),
                     verbose=0)
    loss.append(hist.history['loss'])
    acc.append(hist.history['acc'])
    val_loss.append([hist.history['val_loss']])
    val_acc.append(hist.history['val_acc'])

После чего, например, loss будет:

[[0.7251979386058971,
  0.6640552306833333,
  0.6190941931069023,
  0.5602273066015956,
  0.48771809028534785,
  0.40796665995284814,
  0.33154681897220617,
  0.2698465999525444,
  0.227492357244586,
  0.1998490962115201],
 [0.7109123742507104,
  0.674812126485093,
  0.6452083222258479,
  0.6074533335751673,
  0.5627432800365635,
  0.51291748379345,
  0.45645068427406726,
  0.3928780094229408,
  0.3282097149542538,
  0.26993170230619656],
 [0.7191790426458682,
  0.6618405645963258,
  0.6253172250296091,
  0.5855853647883192,
  0.5438901918195831,
  0.4999895181964501,
  0.4495182811042725,
  0.3896359298090465,
  0.3210068798340545,
  0.25932698793518183]]

то есть список n_splits списков (здесь 3), каждый из которых содержит потерю обучения для каждой эпохи (здесь 10). Аналогично для других списков ...

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...