Я новичок в обучении с подкреплением и пытался использовать LSTM для обучения с подкреплением для агента космических захватчиков. Я попытался использовать сеть, найденную в этой статье , но у меня возникли проблемы:
-Если я использую conv2D, размеры с LSTM не совпадают, и я получаю эту ошибку:
ValueError: Вход 0 несовместим со слоем conv_lst_m2d_1: ожидаемый ndim = 5, найденный ndim = 4
Это код:
self.model = Sequential()
self.model.add(Conv2D(32,kernel_size=8,strides=4,activation='relu',input_shape=(None,84,84,1)))
self.model.add(Conv2D(64,kernel_size=4,strides=2,activation='relu'))
self.model.add(Conv2D(64,kernel_size=3, strides=1,activation='relu'))
self.model.add(ConvLSTM2D(512, kernel_size=(3,3), padding='same', return_sequences=False))
self.model.add(Dense(4, activation='relu'))
self.model.compile(loss='mse', optimizer=Adam(lr=0.0001))
self.model.summary()
-И если Я использую Conv3D, который выводит тензор 5D. Я не могу использовать одно изображение в качестве ввода:
ValueError: Ошибка при проверке ввода: ожидалось, что conv3d_1_input имеет 5 измерений, но получил массив с формой (1, 84, 84, 1)
Код:
self.model.add(Conv3D(32,kernel_size=8,strides=4,activation='relu',input_shape=(None,84,84,1)))
self.model.add(Conv3D(64,kernel_size=4,strides=2,activation='relu'))
self.model.add(Conv3D(64,kernel_size=3, strides=1,activation='relu'))
self.model.add(ConvLSTM2D(512, kernel_size=(3,3), padding='same', return_sequences=False))
self.model.add(Dense(4, activation='relu'))
self.model.compile(loss='mse', optimizer=Adam(lr=0.0001))
self.model.summary()
(редактировать)
Сводная информация о сети (второй сети):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv3d_1 (Conv3D) (None, None, 20, 20, 32) 16416
_________________________________________________________________
conv3d_2 (Conv3D) (None, None, 9, 9, 64) 131136
_________________________________________________________________
conv3d_3 (Conv3D) (None, None, 7, 7, 64) 110656
_________________________________________________________________
conv_lst_m2d_1 (ConvLSTM2D) (None, 7, 7, 512) 10618880
_________________________________________________________________
dense_1 (Dense) (None, 7, 7, 4) 2052
=================================================================
И форма ввода данных: (84, 84, 1)