Классификация MNIST: функция потерь mean_squared_error и функция активации tanh - PullRequest
0 голосов
/ 19 ноября 2018

Я изменил пример начала работы Tensorflow следующим образом:

import tensorflow as tf
from sklearn.metrics import roc_auc_score
import numpy as np
import commons as cm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  # tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  # tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.tanh)
])
model.compile(optimizer='adam',
               loss='mean_squared_error',
              # loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

history = cm.Histories()
h= model.fit(x_train, y_train, epochs=50, callbacks=[history])
print("history:", history.losses)
cm.plot_history(h)
# cm.plot(history.losses, history.aucs)


test_predictions = model.predict(x_test)


# Compute confusion matrix
pred = np.argmax(test_predictions,axis=1)
pred2 = model.predict_classes(x_test)
confusion = confusion_matrix(y_test, pred)
cm.draw_confusion(confusion,range(10))

со своими параметрами по умолчанию:

  • relu активация на скрытых слоях,
  • softmax на выходном слое и
  • sparse_categorical_crossentropy как функция потерь,

работает нормально, и прогноз для всех цифр выше 99%

Однако с моими параметрами: tanh функция активации и mean_squared_error функция потерь она просто предсказывает 0 для всех тестовых образцов:

enter image description here

Интересно, в чем проблема?Точность увеличивается для каждой эпохи и достигает 99%, а потери составляют около 20

1 Ответ

0 голосов
/ 19 ноября 2018

Вам необходимо использовать правильную функцию потерь для ваших данных.Здесь у вас есть категориальный вывод, поэтому вам нужно использовать sparse_categorical_crossentropy, но также установить from_logits без какой-либо активации для последнего слоя.

Если вам нужно использовать tanh в качестве вывода, то выможете использовать MSE с версией ваших ярлыков в горячем коде + масштабирование.

...