Я пытаюсь реализовать Пространственную Трансформаторную Сеть из здесь , и я сталкиваюсь с этой проблемой:
class STNLayer(torch.nn.Module):
def __init__(self, input_size):
super(STNLayer, self).__init__()
self.input_size = input_size
self.localization = nn.Sequential(
nn.Conv2d(self.input_size, 8, kernel_size = 7),
nn.MaxPool2d(2, stride = 2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size = 5),
nn.MaxPool2d(2, stride = 2),
nn.ReLU(True)
)
self.fc_loc = nn.Sequential(
nn.Linear(10 * 12 * 12, 32),
nn.ReLU(True),
#nn.BatchNorm1d(32),
nn.Linear(32, 3*2)
)
# Initialize weights to identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data = torch.cuda.FloatTensor([1,0,0,0,1,0])
Линия
self.fc_loc[2].bias.data = torch.cuda.FloatTensor([1,0,0,0,1,0])
выдаёт ошибку:
*** AttributeError: module 'torch' has no attribute 'float'
Как мне это исправить?