Я относительно новичок в сетях Convolutional LSTM, но в настоящее время я работаю над проблемой, которая предполагает прогнозирование последовательности кадров в будущем, поэтому я решил изучить сети ConvLSTM.
Чтобы понять, как работает модель и как я могу ее расширить, я попробовал несколько начальных тестов для набора данных Moving MNIST, доступных здесь:
http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
Однако, после обучения и выполнения вывода - я бы подумал, что предсказания будут более связными, особенно по сравнению с другими людьми, которые использовали аналогичные подходы для набора данных Moving MNIST. Кажется, что исходная траектория цифр «сохраняется» в выводе.
Это общее ограничение или моя сетевая архитектура неправильно спроектирована для данной задачи?
Настройка
Я прочитал и применил код, приведенный в следующей статье:
https://arxiv.org/abs/1506.04214
и у них также есть страница Github, где я в основном использовал их пример keras для ячейки ConvLSTM:
https://github.com/wqxu/ConvLSTM
Я уменьшил размер выборки до 100, чтобы вы могли воспроизвести результаты, но я обучил модель с графическим процессором K40 для 100 эпох (занимает около часа), чтобы посмотреть, не связана ли проблема с моделью, просто не сходящейся .
Мой код выглядит следующим образом (при условии, что вы загрузили набор данных Moving MNIST по указанной выше ссылке и поместили его в переменную 'path'):
from keras.models import Sequential
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
import numpy as np
import matplotlib.pyplot as plt
path = "./"
data = np.load(path + 'mnist_test_seq.npy')
# Define image dimensions and frames to be used for LSTM memory
sequence_length = 15
image_height = data.shape[2]
image_width = data.shape[3]
# swap frames and observations so [obs, frames, height, width, channels]
data = data.swapaxes(0, 1)
# only select first 100 observations to reduce memory- and compute requirements
sub = data[:100, :, :, :]
# add channel dimension (grayscale)
sub = np.expand_dims(sub, 4)
# normalize to 0, 1
#sub = sub / 255
sub[sub < 128] = 0
sub[sub>= 128] = 1
# Define network
seq = Sequential()
seq.add(ConvLSTM2D(filters=64, kernel_size=(1,1),
input_shape=(None, image_height, image_width, 1), #Will need to change channels to 3 for real images
padding='same', return_sequences=True,
activation='relu'))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=64, kernel_size=(2,2),
padding='same', return_sequences=True,
activation='relu'))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=64, kernel_size=(1,1),
padding='same', return_sequences=True,
activation='relu'))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=64, kernel_size=(2,2),
padding='same', return_sequences=True,
activation='relu'))
seq.add(BatchNormalization())
seq.add(Conv3D(filters=1, kernel_size=(1,1,1),
activation='sigmoid',
padding='same', data_format='channels_last'))
seq.compile(loss='binary_crossentropy', optimizer='adam')
# Add helper function for shifting input and output, so previous frame (X_t-1) is used as input to predict next frame (y_t)
def shift_data(data, n_frames=15):
X = data[:, 0:n_frames, :, :, :]
y = data[:, 1:(n_frames+1), :, :, :]
return X, y
# Run script
# prepare X, y
X, y = shift_data(sub, sequence_length)
# fit the model
seq.fit(X, y, batch_size=16, epochs=100, validation_split=0.05)
# select a random observation
test_set = np.expand_dims(X[5, :, :, :, :], 0)
# compare to ground truth and visualize
for i in range(0, 13):
# create plot
fig = plt.figure(figsize=(10, 5))
# truth
ax = fig.add_subplot(122)
ax.text(1, -3, ('ground truth at time :' + str(i)), fontsize=20, color='b')
toplot_true = test_set[0, i, ::, ::, 0]
plt.imshow(toplot_true)
# predictions
ax = fig.add_subplot(121)
ax.text(1, -3, ('predicted frame at time :' + str(i)), fontsize=20, color='b')
toplot_pred = prediction[0, i+1, ::, ::, 0]
plt.imshow(toplot_pred)
plt.savefig(path + '/%i_image.png' % (i + 1))
Я получаю следующие результаты:
Первое изображение выглядит довольно хорошо
Рамка 1
Однако Кадр 6 и Кадр 13 Кадр четко показывает вам всю траекторию предыдущих шагов.
И если вы визуализируете все изображения одновременно, становится ясно, что траектория цифр никогда не "удаляется" из изображения.
Я не уверен, является ли это просто некоторым известным ограничением модели или это не просто сходящаяся модель. Меня беспокоит то, что эти результаты не слишком удовлетворительны, учитывая относительную простоту набора данных, и что более сложные задачи будут просто невозможны для модели.
Любая обратная связь будет принята с благодарностью!