Я пытаюсь объединить модель классификации CNN-LSTM, но у меня появляется следующая ошибка:
ValueError: Вход 0 несовместим со слоем flatten_2: ожидается min_ndim = 3, найдено ndim = 2
Среда:
Python 3.5
Keras 2.2.0
Tf-GPU 1.6.0
Любые идеи о том, как я могу решить проблему? Большое спасибо!
from keras.layers import Convolution2D, MaxPooling2D, Flatten, Reshape
from keras.models import Sequential
from keras.utils.np_utils import to_categorical
from keras.layers.wrappers import TimeDistributed
from keras.layers.pooling import GlobalAveragePooling1D
import gc
import numpy as np
timesteps = 100;
number_of_samples = 2500;
nb_samples = number_of_samples;
frame_row = 32;
frame_col = 32;
channels = 3;
nb_epoch = 1;
batch_size = timesteps;
data = np.random.random((2500, timesteps, frame_row, frame_col, channels))
label = np.random.randint(4, size=(2500, 1))
X_train = data[0:2000, :]
y_train = label[0:2000]
y_train = to_categorical(y_train)
X_test = data[2000:, :]
y_test = label[2000:, :]
# %%
model = Sequential();
model.add(TimeDistributed(Convolution2D(32, 3, 3, border_mode='same'), input_shape=(100, 32, 32, 3)))
model.add(TimeDistributed(Convolution2D(32, 3, 3, border_mode='same'), input_shape=(100, 32, 32, 3)))
model.add(TimeDistributed(Activation('relu')))
model.add(TimeDistributed(Convolution2D(32, 3, 3)))
model.add(TimeDistributed(Activation('relu')))
model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))
model.add(TimeDistributed(Dropout(0.25)))
model.add(TimeDistributed(Flatten()))
model.add(TimeDistributed(Dense(512)))
model.add(TimeDistributed(Dense(35, name="first_dense")))
model.add(LSTM(20, return_sequences=True, name="lstm_layer"));
# %%
model.add(TimeDistributed(Dense(4), name="time_distr_dense_one"))
model.add(GlobalAveragePooling1D(name="global_avg"))
model.add(Flatten())
model.add(TimeDistributed(Dense(4, activation="softmax"), name="time_distr_dense"))
# %%
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(X_train, y_train, epochs=3, validation_split=0.1, batch_size=32, verbose=2)
gc.collect()