pytorch RuntimeError: Ожидаемый объект типа torch.cuda.LongTensor, но для аргумента найден тип torch.LongTensor - PullRequest
0 голосов
/ 28 января 2019

Я реализую категорический dqn для обучения с подкреплением и получаю следующую ошибку:

proj_dist.view (-1) .index_add_ (0, (l + offset) .view (-1), (next_dist * (u.float () - b)). view (-1))

RuntimeError: Ожидаемый объект типа torch.cuda.LongTensor, но найден тип torch.LongTensor для аргумента # 3 'other '

Я уже пытался изменить третий аргумент первого вызова на index_add_ как:

  • .long().cuda()
  • long().to(device)...

ничего не работает

Вот код, где все параметры функции являются тензорами.Обратите внимание, что я работаю с партией 64, так что, например, next_states будет примерно 64 тензора:

def projection_distribution(self, next_states, rewards, dones):
        batch_size  = next_states.size(0)

        delta_z = float(self.V_max - self.V_min) / (self.num_atoms - 1)
        support = torch.linspace(self.V_min, self.V_max, self.num_atoms)

        # torch.Size([64, 4, 51])
        next_dist   = self.qnetwork_target(next_states).data.cpu() * 
                       support.data.cpu()
        # torch.Size([64])
        next_actions = next_dist.sum(2).max(1)[1]
        # torch.Size([64, 1, 51])
        next_actions = 
        next_actions.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, 
        next_dist.size(2))
        # these are the batch_size distributions relative to the optimal 
        actions 'next_actions'
        # torch.Size([64, 51])
        next_dist   = next_dist.gather(1, next_actions).squeeze(1)

        rewards = rewards.expand_as(next_dist)
        print('rewards: ', rewards.size())
        dones   = dones.expand_as(next_dist)
        support = support.unsqueeze(0).expand_as(next_dist)
        print('support: ', support.size())

        Tz = rewards + (1 - dones) * GAMMA * support.to(device)
        Tz = Tz.clamp(min=self.V_min, max=self.V_max)
        b  = (Tz - self.V_min) / delta_z
        l  = b.floor().long()
        u  = b.ceil().long()
        print('TYPE: ' , b.type(), l.type(), u.type(), next_dist.type())

        offset = torch.linspace(0, (batch_size - 1) * self.num_atoms, batch_size).long()\
                        .unsqueeze(1).expand(batch_size, self.num_atoms)
        print('offset: ', offset.size())

        proj_dist = torch.zeros(next_dist.size())  
        print('proj_dist_1: ', proj_dist.size())
        proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
        print('proj_dist_2: ', proj_dist.size())
        proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
        print('proj_dist_3: ', proj_dist.size())

        return proj_dist

Почему я получаю эту ошибку?Как мне это исправить?

Большое спасибо Макс

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...