Несовместимые формы в функции керас - PullRequest
0 голосов
/ 26 марта 2019

Я пытаюсь реализовать сеть критиков-актеров, используя keras (и тензор потока 2.0 альфа с отключенным нетерпеливым выполнением), но, похоже, ошибка в функции keras, которая обновляет вес сети актера.

Я оставляю свои операторы print () в коде, чтобы показать то, что я уже исследовал, и, надеюсь, восполнить тот факт, что этот код неполон и, следовательно, не воспроизводим. Редактировать: Вот моя модель актера и критика

enter image description here

Функция вызывается следующим образом, и я выводю формы всех входных массивов:

# Networks optimization
print('Shapes of vars: states: {}, actions: {}, advantages: {}'.format(
    np.array(states).shape, np.array(actions).shape, np.array(advantages).shape))
self.a_opt([states, actions, advantages]) # call the keras function written out below
# a print statement here is never reached

Вызываемая функция (которая выводит ошибку «несовместимые формы») выглядит следующим образом:

def a_opt(self):
    """ Actor Optimization: Advantages + Entropy term to encourage exploration
    (Cf. https://arxiv.org/abs/1602.01783)
    """
    modelout = K.print_tensor(
        self.model.output, message="model output: " + str(K.int_shape(self.model.output)))
    action_pl = K.print_tensor(
        self.action_pl, message="action_pl: " + str(K.int_shape(self.action_pl)))

    weighted_actions = K.sum(action_pl * modelout, axis=1)
    weighted_actions = K.print_tensor(
        weighted_actions, message="weighted_actions: ")

    eligibility = K.log(weighted_actions + 1e-10) * \
        K.stop_gradient(self.advantages_pl)
    eligibility = K.print_tensor(eligibility, message="eligibility: ")

    entropy = K.sum(modelout *
                    K.log(modelout + 1e-10), axis=1)
    entropy = K.print_tensor(entropy, message="entropy: ")

    loss = 0.001 * entropy - K.sum(eligibility)
    loss = K.print_tensor(loss, message="loss: ")

    updates = self.rms_optimizer.get_updates(loss=loss,
                                             params=self.model.trainable_weights)
    return K.function([self.model.input, self.action_pl, self.advantages_pl], [], updates=updates)

Пока все хорошо, но выполнение программы дает следующий вывод консоли:

Shapes of vars: states: (999, 1, 44), actions: (999, 3), advantages: (999,)
action_pl: (None, 3)[[0.861626744 0.928109825 0.0259102583...]...]
model output: (None, 3)[[0.365334 0.333090335 0.301575601]]
Traceback (most recent call last):
  File ".\actor_critic.py", line 85, in <module>
    agent.train(marketSim, ac_args, summary_writer)
  File "FILEPATH", line 115, in train
    self.train_models(states, actions, rewards, done)
  File "FILEPATH", line 76, in train_models
    self.a_opt([states, actions, advantages])
  File "C:\Python37\lib\site-packages\tensorflow\python\keras\backend.py", line 3096, in __call__
    run_metadata=self.run_metadata)
  File "C:\Python37\lib\site-packages\tensorflow\python\client\session.py", line 1440, in __call__
    run_metadata_ptr)
  File "C:\Python37\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 548, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [1,44] vs. [1,3]
         [[{{node gradients/mul_grad/BroadcastGradientArgs}}]]
         [[Sum_1/_119]]

Как видите, размер партии равен 999, state имеет форму (1,44) и actions имеют форму (3,). Из сообщения об ошибке я предполагаю, что я где-то умножаю эти два, но не могу найти, где это происходит. Я также не понимаю, почему action_pl и model_output имеют одинаковую форму (None, 3), но, хотя для model_output это, очевидно, правильно, распечатанный тензор action_pl выглядит так, как будто он может иметь другую форму (1,44, может быть?), Что меня полностью смущает, поскольку список, который я передаю функции как actions, определенно имеет форму (999, 3)

Также я не уверен, какая строка на самом деле вызывает ошибку: в соответствии со строками print_tensor я бы предположил, что строка weighted_actions = K.sum..., потому что ничто ниже этого (или глубже в графе вычислений) ничего не выводит, но это может быть неправильно.

tl; dr: Какая строка a_opt() на самом деле приводит к ошибке, откуда берется форма [1,44] в ошибке и есть ли лучший способ для отладки графов вычислений, подобных этой?

...