«Функция вызова не соответствует сигнатуре базового метода в классе« Сеть »» - построение модели A3C - PullRequest
0 голосов
/ 20 октября 2019

Я пытаюсь протестировать A3C-модель, которую мне предоставили в этом руководстве: https://medium.com/tensorflow/deep-reinforcement-learning-playing-cartpole-through-asynchronous-advantage-actor-critic-a3c-7eab2eea5296

Во-первых, поскольку алгоритм не распознает все пакеты (хотя и установлены), я попыталсяссылаются на них по-разному:

  • из тензорного потока импорта керас (не работает)
  • из тензорного потока._api.v1 импорт кераса (работает)

Теперь у меня есть следующая базовая концепция A3C-модели:

class ActorCriticModel (keras.Model):

def __init__(self, state_size, action_size):                  
    super(ActorCriticModel, self).__init__()
    self.state_size = state_size
    self.action_size = action_size
    self.dense1 = layers.Dense(100, activation='relu')
    self.policy_logits = layers.Dense(action_size)
    self.dense2 = layers.Dense(100, activation='relu')
    self.values = layers.Dense(1)

def call(self, inputs):
    # Forward pass
    x = self.dense1(inputs)
    logits = self.policy_logits(x)

    v1 = self.dense2(inputs)
    values = self.values(v1)
    return logits, values

model = ActorCriticModel ()

Предоставление пакетов:

  • из слоев импорта keras (который распознается как использованный пакет)
  • из ввода импорта keras.layers (не распознается как использованный пакет)

В настоящее время я сталкиваюсь с проблемой того, что входные данные в моей модели не "совпадают с сигнатурой базового метода в классе 'Network'".

Как мне решить эту проблему?

О программеверсии программ и API, с которыми я работаю:

  • Python 3.7.3
  • Tensorflow 1.13.1
  • Керас 2.3.1
...