Как передать мульти вход в train_on_batch в Керасе - PullRequest
0 голосов
/ 25 апреля 2019
 ValueError: could not broadcast input array from shape (60,60,2) into shape (1)

Я пытался каким-то образом изменить мой код, но все еще с той же ошибкой.

  1. state.append (np.array (s)) # mark 1 target_f_list.append (np.array (target_f)) # mark 2
  2. self.model.train_on_batch ([состояние], [target_f_list]) # mark 3
  3. self.model.train_on_batch (np.array (состояние), np.array (target_f_list)) # mark 3

Это моя сеть Керас:

    input_1 = Input(shape=(60, 60, 2))
    input_2 = Input(shape=(self.action_size, self.action_size))
    x1 = Conv2D(32, (4, 4), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(input_1)
    x1 = Conv2D(64, (2, 2), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Conv2D(128, (2, 2), strides=(1, 1), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Flatten()(x1)
    x1 = Dense(128, activation=LeakyReLU(alpha=self.Beta))(x1)
    x1_value = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    value = Dense(1, activation=LeakyReLU(alpha=self.Beta))(x1_value)
    x1_advantage = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    advantage = Dense(self.action_size, activation=LeakyReLU(alpha=self.Beta))(x1_advantage)

    A = Dot(axes=1)([input_2, advantage])
    A_subtract = Subtract()([advantage, A])

    Q_value = Add()([value, A_subtract])

    model = Model(inputs=[input_1, input_2], outputs=[Q_value])
    model.compile(optimizer=Adam(lr=self.epsilon_r), loss='mse')

Это моя функция тренироваться:

    state = []
    target_f_list = []
    for s, a, r, next_s, done in minibatch:
        if not done:

            ... do calculate target_f ...

            state.append(s)                   # mark 1
            target_f_list.append(target_f)    # mark 2

            # this is fit function i use before and it's worked fine. But i want to train all minibatch add the same time.
            # self.model.fit(s, target_f, epochs=1, verbose=0, batch_size=self.minibatch_size)

    # This is my code has error
    self.model.train_on_batch(state,target_f_list)  # mark 3

Спасибо, что прочитали мой вопрос.

...