Я получил NaN в функции потери Keras с первой эпохи - PullRequest
0 голосов
/ 11 февраля 2020

Я получил потерю NaN с первой эпохи. Форма train_data (891,13). Форма train_labels - (891,2). Я создаю эту модель для конкурса Titani c в Kaggle.


from keras import models
from keras import layers
import tensorflow as tf


def build_model():
    model = models.Sequential()
    model.add(layers.Dense(64, activation='relu', input_shape=(train_data.shape[1],), kernel_initializer='normal', bias_initializer='zeros'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(2, activation='sigmoid'))
    model.compile(loss='categorical_crossentropy',
                 optimizer='Adam',
                 metrics=['accuracy'])
    return model
k = 3
num_val_samples = len(train_data) // k
num_epochs = 100
all_scores = []
for i in range(k):
    print('processing fold #', i)
    #検証データの準備
    val_data = train_data[i * num_val_samples: (i+1) * num_val_samples]
    val_labels = train_labels[i * num_val_samples: (i+1) * num_val_samples]
    #訓練データの準備
    partial_train_data = np.concatenate([train_data[:i * num_val_samples], train_data[(i+1) * num_val_samples:]], axis=0)
    partial_train_labels = np.concatenate([train_labels[:i * num_val_samples], train_labels[(i+1) * num_val_samples:]], axis=0)

model = build_model()
history = model.fit(partial_train_data,
                    partial_train_labels,
                    epochs=num_epochs,
                    batch_size=1,
                   validation_data=(val_data,val_labels))

1 Ответ

0 голосов
/ 11 февраля 2020

Возможно, проблема связана с определением вашей модели:

Не используйте Sigmoid активацию с потерей categorical_crossentropy:

  • Двоичная проблема: используйте Sigmoid с binary_crossentropy
  • Проблема с несколькими классами: используйте Softmax с categorical_crossentropy

РЕДАКТИРОВАТЬ

Еще несколько подсказок:

  • Проверьте свои данные, возможно, вы вводите поврежденное значение в свою сеть, проверьте наличие NaN в наборе данных
  • Попробуйте другие оптимизаторы
  • Попробуйте масштабировать данные или попытаться изменить их по-другому
  • Попробуйте увеличить размер партии
...