I've defined a fairly simple MLP network with a policy-gradient-style loss, and I'm getting a rather strange error where tf.tile is handed the wrong number of multiples while the gradient is being computed.
Here's the error:
2019-10-06 11:48:25.535482: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: Expected multiples argument to be a vector of length 1 but got length 2
[[{{node gradients/loss/lambda_1_loss/policy_gradient_loss/weighted_loss/Sum_grad/Tile}}]]
0%| | 0/100000 [00:01<?, ?it/s]
Traceback (most recent call last):
File "reinforce.py", line 208, in <module>
main(sys.argv)
File "reinforce.py", line 196, in main
reward, loss = policy.train(env, gamma=gamma)
File "reinforce.py", line 36, in train
loss = self.network.train_step(states, actions, discounted_rewards)
File "reinforce.py", line 129, in train_step
history = self.model.fit([states, actions_one_hot], discounted_rewards, verbose=1)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/keras/engine/training.py", line 1239, in fit
validation_freq=validation_freq)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/keras/engine/training_arrays.py", line 196, in fit_loop
outs = fit_function(ins_batch)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3740, in __call__
outputs = self._graph_fn(*converted_inputs)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1081, in __call__
return self._call_impl(args, kwargs)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
return self._call_flat(args, self.captured_inputs, cancellation_manager)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
ctx, args, cancellation_manager=cancellation_manager)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
ctx=ctx)
File "/spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected multiples argument to be a vector of length 1 but got length 2
[[node gradients/loss/lambda_1_loss/policy_gradient_loss/weighted_loss/Sum_grad/Tile (defined at /spin/scr/rl_hw3/rl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_3802]
Function call stack:
keras_scratch_graph
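For context, this is the contract tf.tile enforces: the multiples argument must have exactly one entry per dimension of its input, so handing a length-2 multiples to a rank-1 tensor produces exactly the "Expected multiples argument to be a vector of length 1 but got length 2" message above. A tiny standalone illustration (assumes eager execution, as in TF 2.0):
import tensorflow as tf
x = tf.constant([1., 2., 3.])   # rank-1 tensor
tf.tile(x, [2])                 # fine: multiples has length 1
# tf.tile(x, [2, 2])            # raises InvalidArgumentError: expected length 1, got 2
In my case the mismatched multiples comes from a Tile op that TensorFlow builds internally for the gradient of the loss's weighted Sum (the gradients/.../weighted_loss/Sum_grad/Tile node in the traceback), not from my own code.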
The network is defined as
self.input_dimension = state_dimension
self.output_dimension = num_actions
self.layers = [16, 16, 16, self.output_dimension]
self.initializer = VarianceScaling(scale=3., mode="fan_avg", distribution="uniform")

inputs = K.Input(shape=(self.input_dimension,), name='states')
prev = inputs
for i, layer in enumerate(self.layers):
    if i + 1 != len(self.layers):
        prev = BatchNormalization()(Dense(units=layer,
                                          kernel_initializer=self.initializer,
                                          bias_initializer='zeros',
                                          activation='relu')(prev))
    else:
        prev = Dense(units=layer,
                     activation='softmax',
                     kernel_initializer=self.initializer,
                     bias_initializer='zeros',
                     name='probs')(prev)

actions_one_hot = K.Input(shape=(self.output_dimension,), name='one_hot_actions')
action_probs = Multiply(name='action_probs')([actions_one_hot, prev])
log_action_probs = Lambda(lambda x: B.log(B.sum(x, axis=1)))(action_probs)

self.model = Model(inputs=[inputs, actions_one_hot], outputs=log_action_probs)
self.model.compile(loss=policy_gradient_loss, optimizer=K.optimizers.Adam(learning_rate=self.learning_rate))
self.pred_model = Model(inputs=[inputs], outputs=self.model.get_layer('probs').output)
print(self.model.summary())
and the loss function is
def policy_gradient_loss(discounted_rewards, log_action_probs):
    # return -B.mean(log_action_probs * discounted_rewards)
    return -Multiply()([log_action_probs, discounted_rewards])
I've tried adding/removing singleton dimensions on the inputs, but I believe the shapes are correct. If I remove the dependence on discounted_rewards in the loss function, it computes gradients and trains, but obviously I need to include them. I also reproduced the error with TF 1.14, so I don't think this is a regression. Any advice?
EDIT: figured it out - shape inference for the Lambda layer doesn't seem to work; forcing the output shape to be what it should be solves the problem, e.g.
log_action_probs = Lambda(lambda x: B.log(B.sum(x, axis=1)), output_shape=(None,))(action_probs)
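For reference, here is a minimal, self-contained sketch (my own toy setup, not the actual reinforce.py code; layer sizes and data are made up) of a related workaround: keeping the reduced axis with keepdims=True so the Lambda output is explicitly (batch, 1) and Keras never has to infer a rank-1 shape for it:
import numpy as np
from keras import Input, Model
from keras import backend as B
from keras.layers import Dense, Lambda, Multiply

n_states, n_actions = 4, 2

states = Input(shape=(n_states,), name='states')
actions_one_hot = Input(shape=(n_actions,), name='one_hot_actions')

probs = Dense(n_actions, activation='softmax', name='probs')(states)
action_probs = Multiply(name='action_probs')([actions_one_hot, probs])
# keepdims=True keeps the summed axis, so the output shape is (batch, 1)
log_action_probs = Lambda(lambda x: B.log(B.sum(x, axis=1, keepdims=True)))(action_probs)

def policy_gradient_loss(discounted_rewards, log_action_probs):
    # element-wise product of two (batch, 1) tensors; Keras reduces it over the batch
    return -log_action_probs * discounted_rewards

model = Model(inputs=[states, actions_one_hot], outputs=log_action_probs)
model.compile(loss=policy_gradient_loss, optimizer='adam')

# random batch just to check that fit() can build the gradient graph
s = np.random.rand(8, n_states).astype('float32')
a = np.eye(n_actions)[np.random.randint(n_actions, size=8)].astype('float32')
r = np.random.rand(8, 1).astype('float32')
model.fit([s, a], r, verbose=0)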