Я пытаюсь вычислить текущий Q для глубокого обучения Q, но сталкиваюсь с этой ошибкой:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorScatterGather.cu:16
Код, вызывающий эту ошибку:
curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))
Форма self.model.forward(states)
- это [32, 640, 10], а форма действий - [32].
Остальная часть кода в этой функции находится здесь:
def compute_loss(self, batch):
states, actions, rewards, next_states, dones = batch
states = torch.FloatTensor(states).to(self.device)
actions = torch.LongTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).to(self.device)
next_states = torch.FloatTensor(next_states).to(self.device)
dones = torch.FloatTensor(dones)
curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))
curr_Q = curr_Q.squeeze(1)
next_Q = self.model.forward(next_states)
max_next_Q = torch.max(next_Q, 1)[0]
expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q
loss = self.MSE_loss(curr_Q, expected_Q)
return loss