Проблема с использованием последовательной модели Keras для пакета «inclearlearn »в R - PullRequest
1 голос
/ 01 апреля 2020

Я пытаюсь использовать нейронную сеть / последовательную модель keras (версия 2.2.50) для создания простого агента в условиях обучения подкреплению с использованием пакета reinforcelearn (версия 0.2.1) в соответствии с этой виньеткой: https://cran.r-project.org/web/packages/reinforcelearn/vignettes/agents.html. Вот код, который я использую:

library('reinforcelearn')
library('keras')

model = keras_model_sequential() %>% 
  layer_dense(units = 10, input_shape = 4, activation = "linear") %>%
  compile(optimizer = optimizer_sgd(lr = 0.1), loss = "mae")

agent = makeAgent(policy = "softmax", val.fun = "neural.network", algorithm = "qlearning",
                  val.fun.args = list(model= model))

Однако, когда я пытаюсь запустить функцию makeAgent, я получаю следующее сообщение об ошибке:

Error in .subset2(public_bind_env, "initialize")(...) : 
  Assertion on 'model' failed: Must inherit from class 'keras.models.Sequential', but has classes 'keras.engine.sequential.Sequential','keras.engine.training.Model','keras.engine.network.Network','keras.engine.base_layer.Layer','tensorflow.python.module.module.Module','tensorflow.python.training.tracking.tracking.AutoTrackable','tensorflow.python.training.tracking.base.Trackable','python.builtin.object'.

Кажется, проблема заключается в неправильный класс модели, но что я мог сделать, чтобы решить эту проблему?

1 Ответ

0 голосов
/ 06 апреля 2020

Мне удалось решить проблему, загрузив исходный код из CRAN (https://cran.r-project.org/src/contrib/reinforcelearn_0.2.1.tar.gz) и закомментировав соответствующую строку в определении функции ValueNetwork R6 class / initialise:

ValueNetwork = R6::R6Class("ValueNetwork",
  public = list(
    model = NULL,

    # keras model # fixme: add support for mxnet
    initialize = function(model) {
      # checkmate::assertClass(model, "keras.models.Sequential")
      self$model = model
    },
...

Затем я просто переустановил пакет из источника через: install.packages([file path], repos = NULL, type="source")

...