Функция сбора Pytorch для получения текущих значений Q для Deep Q - PullRequest
0 голосов
/ 13 февраля 2020

Я пытаюсь вычислить текущий Q для глубокого обучения Q, но сталкиваюсь с этой ошибкой:

RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorScatterGather.cu:16

Код, вызывающий эту ошибку:

curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))

Форма self.model.forward(states) - это [32, 640, 10], а форма действий - [32].

Остальная часть кода в этой функции находится здесь:

 def compute_loss(self, batch):
    states, actions, rewards, next_states, dones = batch
    states = torch.FloatTensor(states).to(self.device)
    actions = torch.LongTensor(actions).to(self.device)
    rewards = torch.FloatTensor(rewards).to(self.device)
    next_states = torch.FloatTensor(next_states).to(self.device)
    dones = torch.FloatTensor(dones)

    curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))
    curr_Q = curr_Q.squeeze(1)
    next_Q = self.model.forward(next_states)
    max_next_Q = torch.max(next_Q, 1)[0]
    expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q

    loss = self.MSE_loss(curr_Q, expected_Q)
    return loss
...