керас, классификация MNIST с моделью RNN, проблема с формой вывода - PullRequest
0 голосов
/ 01 марта 2019

Я пытаюсь использовать функциональный API-интерфейс keras для построения рекуррентной нейронной сети, но у меня возникли проблемы с формой вывода, любая помощь будет оценена.

мой код:

import tensorflow as tf
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.layers import Dense, CuDNNLSTM, Dropout
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.utils import normalize
from tensorflow.python.keras.utils import np_utils

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = normalize(x_train, axis=1), normalize(x_test, axis=1)

y_train = np_utils.to_categorical(y_train, 10)
y_test = np_utils.to_categorical(y_test, 10)

feature_input = tf.keras.layers.Input(shape=(28, 28))
x = tf.keras.layers.CuDNNLSTM(128, kernel_regularizer=tf.keras.regularizers.l2(l=0.0004), return_sequences=True)(feature_input)
y = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=feature_input, outputs=y)
opt = tf.keras.optimizers.Adam(lr=1e-3, decay=1e-5)
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))

ОШИБКА:

ValueError: Ошибка при проверке цели: ожидается, что плотный имеет 3 измерения, но получил массив с формой (60000, 10)

1 Ответ

0 голосов
/ 02 марта 2019

Ваши данные (цели) имеют форму (60000, 10).

Выход вашей модели («плотный») имеет форму (None, length, 10).

Где None - размер партии (переменная), length - среднее измерение, что означает «время».шаги "для LSTM, а 10 - это единицы уровня Dense.

Теперь у вас нет последовательности шагов времени для обработки в LSTM, это не имеет смысла.Он интерпретирует «строки изображения» как последовательные временные шаги, а «столбцы изображения» как независимые элементы.(Если это не было вашим намерением, вам просто повезло, что это не дало вам ошибку при попытке вставить изображение в LSTM)

В любом случае, вы можете исправить эту ошибку с помощью return_sequences=False (отменитьlength последовательностей).Что не означает, что эта модель является оптимальной для этого случая.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...