Используя ModuleDict, у меня есть: Тип ввода (torch.cuda.FloatTensor) и тип веса (torch.FloatTensor) должны быть одинаковыми - PullRequest
0 голосов
/ 22 апреля 2020

Я пытаюсь в моей __init__ функции:


        self.downscale_time_conv = np.empty(8, dtype=object)
        for i in range(8):
            self.downscale_time_conv[i] = torch.nn.ModuleDict({})

Но в моем forward у меня есть:

        down_out = False
        for i in range(8):
            if not down_out:
                down_out = self.downscale_time_conv[i][side](inputs)
            else:
                down_out += self.downscale_time_conv[i][side](inputs)

, и я получаю:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

1 Ответ

1 голос
/ 22 апреля 2020
        self.downscale_time_conv = torch.nn.ModuleList()
        for i in range(8):
            self.downscale_time_conv.append(torch.nn.ModuleDict({}))

это решило это. Видимо мне нужно было использовать ModuleList

...