keras - Ошибка при проверке цели со слоем встраивания - PullRequest
0 голосов
/ 27 января 2020

Я пытаюсь запустить модель keras следующим образом:

model = Sequential()
model.add(Dense(10, activation='relu',input_shape=(286,)))
model.add(Dense(1, activation='softmax',input_shape=(324827, 286)))

Этот код работает, но если я пытаюсь добавить слой для встраивания:

model = Sequential()
model.add(Embedding(286,64, input_shape=(286,)))
model.add(Dense(10, activation='relu',input_shape=(286,)))
model.add(Dense(1, activation='softmax',input_shape=(324827, 286)))

I ' получаю следующую ошибку:

ValueError: Error when checking target: expected dense_2 to have 3 dimensions, but got array with shape (324827, 1)

Мои данные имеют 286 объектов и 324827 строк. Я, наверное, что-то не так с определениями формы, можете ли вы сказать мне, что это такое? Спасибо

1 Ответ

0 голосов
/ 27 января 2020

Вам не нужно указывать input_shape во втором плотном слое, и ни в первом, только в первом слое будет вычислена следующая форма слоев:

from tensorflow.keras.layers import Embedding, Dense
from tensorflow.keras.models import Sequential

# 286 features and 324827 rows (324827, 286)

model = Sequential()
model.add(Embedding(286,64, input_shape=(286,)))
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='softmax'))
model.compile(loss='mse', optimizer='adam')
model.summary()

возвращает:

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, 286, 64)           18304     
_________________________________________________________________
dense_2 (Dense)              (None, 286, 10)           650       
_________________________________________________________________
dense_3 (Dense)              (None, 286, 1)            11        
=================================================================
Total params: 18,965
Trainable params: 18,965
Non-trainable params: 0
_________________________________________________________________

Надеюсь, это то, что вы ищете

...