Как связать сверточный слой с LSTM в тензорном потоке Кераса - PullRequest
0 голосов
/ 30 марта 2019

Я экспериментирую с архитектурой нейронной сети и пытаюсь подключить 2D-свертку к ячейке LSTM в тензорном потоке Keras.

Вот моя оригинальная модель:

model = Sequential()

model.add(CuDNNLSTM(256, input_shape=(train_x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(4, activation='softmax'))

Работает как магия.

train_x - это 1209 последовательностей, каждый набор имеет 23 номера, а последовательность имеет длину 128. Другими словами - его форма (1209, 128, 23). Вход для модели равен train_x.shape [1:] = (128,23).

Теперь я добавляю 256 плотных слоев перед ячейкой LSTM, изменяю их размер до размера 16x16, добавляем двухмерное свертывание, выравниваю и соединяю его с ячейкой LSTM. (и то же самое для слоев, следующих за ячейкой LSTM).

Я начал с:

model = Sequential()

model.add(Dense(256, input_shape=(train_x.shape[1:]), activation='relu'))
model.add(Dropout(0.2))

model.add(Reshape((16, 16), input_shape=(256,)))
model.add(Conv2D(16, (5,5), activation='relu', padding='same', input_shape=(16,16,1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())

model.add(CuDNNLSTM(256, return_sequences=True))
model.add(Dropout(0.1))
model.add(BatchNormalization())

model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(4, activation='softmax'))

У меня есть две ошибки:

Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=3. Full shape received: [None, 16, 16]

А когда я удаляю свертку и оставляю только слои Reshape и Flatten:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 4194304 values, but the requested shape has 32768
 [[{{node reshape/Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _class=["loc:@training/Adam/gradients/dropout/cond/Merge_grad/cond_grad"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](dropout/cond/Merge, reshape/Reshape/shape)]]

Вы знаете, как с этим бороться?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...