Как смоделировать выходную последовательность разной длины с помощью кодера-декодера в керасе? - PullRequest
0 голосов
/ 14 июня 2019

Я пытаюсь построить архитектуру кодера-декодера в Керасе для видео. У меня есть сверточный LSTM, кодер и декодер работают нормально, если у меня вход и выход одинакового размера, однако я не могу заставить его работать в случае, когда размер вывода отличается от длины ввода.

Сам код из другого вопроса stackoverflow, хотя и связанный с автоэнкодерами. Я хотел бы сейчас иметь декодер для целевых последовательностей.

Приведенный ниже код работает нормально, но как я могу изменить конечный вывод, декодированный не длиной 48 шагов, а, например, скажем 20.

input_seq = Input(shape=(48,32,16,1))  

a = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(input_seq)
a = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True )(a)

b = MaxPooling3D((2,2,2), padding='same')(a)


c = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True)(b)
c = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same',return_sequences=True)(c)

encoded = MaxPooling3D((2,2,2), padding='same', name="encoder")(c)


d = ConvLSTM2D(40, (3, 3), activation='relu', padding='same',return_sequences=True )(encoded)
d = ConvLSTM2D(40, (3, 3), activation='relu', padding='same', return_sequences=True)(d)

e= UpSampling3D((2, 2,2))(d)

##Skip connection
#merge_one = concatenate([b, e])

f = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True) (e)#(e)
f = ConvLSTM2D(40, (3, 3), activation='sigmoid', padding='same', return_sequences=True) (f)#(e)

g = UpSampling3D((2, 2,2))(f)    

decoded = Conv3D(1, (3, 3, 3), activation='sigmoid', padding='same')(g)# (merge_two)

model = Model(input_seq, decoded)
model.compile(optimizer='adadelta', loss='binary_crossentropy')


model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 48, 32, 16, 1)     0         
_________________________________________________________________
conv_lst_m2d_17 (ConvLSTM2D) (None, 48, 32, 16, 40)    59200     
_________________________________________________________________
conv_lst_m2d_18 (ConvLSTM2D) (None, 48, 32, 16, 40)    115360    
_________________________________________________________________
max_pooling3d_3 (MaxPooling3 (None, 24, 16, 8, 40)     0         
_________________________________________________________________
conv_lst_m2d_19 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
conv_lst_m2d_20 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
encoder (MaxPooling3D)       (None, 12, 8, 4, 40)      0         
_________________________________________________________________
conv_lst_m2d_21 (ConvLSTM2D) (None, 12, 8, 4, 40)      115360    
_________________________________________________________________
conv_lst_m2d_22 (ConvLSTM2D) (None, 12, 8, 4, 40)      115360    
_________________________________________________________________
up_sampling3d_5 (UpSampling3 (None, 24, 16, 8, 40)     0         
_________________________________________________________________
conv_lst_m2d_23 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
conv_lst_m2d_24 (ConvLSTM2D) (None, 24, 16, 8, 40)     115360    
_________________________________________________________________
up_sampling3d_6 (UpSampling3 (None, 48, 32, 16, 40)    0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 48, 32, 16, 1)     1081      
=================================================================
Total params: 867,801
Trainable params: 867,801
Non-trainable params: 0
_________________________________________________________________
...