Я реализую категорический dqn для обучения с подкреплением и получаю следующую ошибку:
proj_dist.view (-1) .index_add_ (0, (l + offset) .view (-1), (next_dist * (u.float () - b)). view (-1))
RuntimeError: Ожидаемый объект типа torch.cuda.LongTensor, но найден тип torch.LongTensor для аргумента # 3 'other '
Я уже пытался изменить третий аргумент первого вызова на index_add_
как:
.long().cuda()
long().to(device)
...
ничего не работает
Вот код, где все параметры функции являются тензорами.Обратите внимание, что я работаю с партией 64, так что, например, next_states
будет примерно 64 тензора:
def projection_distribution(self, next_states, rewards, dones):
batch_size = next_states.size(0)
delta_z = float(self.V_max - self.V_min) / (self.num_atoms - 1)
support = torch.linspace(self.V_min, self.V_max, self.num_atoms)
# torch.Size([64, 4, 51])
next_dist = self.qnetwork_target(next_states).data.cpu() *
support.data.cpu()
# torch.Size([64])
next_actions = next_dist.sum(2).max(1)[1]
# torch.Size([64, 1, 51])
next_actions =
next_actions.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1,
next_dist.size(2))
# these are the batch_size distributions relative to the optimal
actions 'next_actions'
# torch.Size([64, 51])
next_dist = next_dist.gather(1, next_actions).squeeze(1)
rewards = rewards.expand_as(next_dist)
print('rewards: ', rewards.size())
dones = dones.expand_as(next_dist)
support = support.unsqueeze(0).expand_as(next_dist)
print('support: ', support.size())
Tz = rewards + (1 - dones) * GAMMA * support.to(device)
Tz = Tz.clamp(min=self.V_min, max=self.V_max)
b = (Tz - self.V_min) / delta_z
l = b.floor().long()
u = b.ceil().long()
print('TYPE: ' , b.type(), l.type(), u.type(), next_dist.type())
offset = torch.linspace(0, (batch_size - 1) * self.num_atoms, batch_size).long()\
.unsqueeze(1).expand(batch_size, self.num_atoms)
print('offset: ', offset.size())
proj_dist = torch.zeros(next_dist.size())
print('proj_dist_1: ', proj_dist.size())
proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
print('proj_dist_2: ', proj_dist.size())
proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
print('proj_dist_3: ', proj_dist.size())
return proj_dist
Почему я получаю эту ошибку?Как мне это исправить?
Большое спасибо Макс