Проблема здесь в том, что ваш вводный ввод использует double
в качестве типа данных, тот же тип данных также применяется к результирующему тензору.
weights
вашего слоя self.fully_connected
с другой стороны - float
. При подаче данных через слой применяется матричное умножение, и для этого умножения требуется, чтобы обе матрицы имели одинаковый тип данных.
Итак, у вас есть два решения:
- Вы можете преобразовать свой ввод в число с плавающей точкой:
Изменяя:
gen(torch.from_numpy(np.random.normal(size=100)))
Кому:
gen(torch.from_numpy(np.random.normal(size=100)).float())
Ваш вход, который подается в gen
, будет преобразован в float
, затем.
Полный рабочий код для преобразования входных данных:
from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)
def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())
gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)).float()) # converting network input to float
- Или, в качестве альтернативы, вы можете конвертировать веса вашего слоя в удвоенные значения:
Если вам нужна двойная точность , вы также можете конвертировать weights
в double
.
Изменить эту строку:
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False)
Просто:
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double()
Полный рабочий код для преобразования весов:
from torch import nn
import torch
import numpy as np
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fully_connected = nn.Linear(100, 1024*4*4, bias=False).double() # converting layer weights to double()
def forward(self, zvec):
print(zvec.size())
fc = self.fully_connected(zvec)
return(fc.size())
gen = Generator();
gen(torch.from_numpy(np.random.normal(size=100)))
Таким образом, оба способа должны работать для вас, но если вам не нужна дополнительная точность double
, вы должны использовать float
, поскольку double
требует большей вычислительной мощности.