Ниже приведен код, который я использую -
model_ft = models.vgg16(pretrained=True)
for param in model_ft.parameters():
param.requires_grad = False
class Vgg_added_features(nn.Module):
def __init__(self, originalModel):
super(Vgg_added_features, self).__init__()
self.features = nn.Sequential(*list(originalModel.features)[:-1])
self.classifier = nn.Linear(512*512, num_classes)
#self.avg_pool = nn.AdaptiveAvgPool2d((7,7))
def forward(self, x):
print(x.shape)
x = self.features(x).view(-1,512,12*14*14)
print(x.shape)
x = torch.matmul(x, x.permute(0,2,1)).view(-1,512*512)/12*14*14.0
print(x.shape)
x = torch.mul(torch.sign(x),torch.sqrt(torch.abs(x)+1e-12))
print(x.shape)
x = F.normalize(x, p=2, dim=1)
print(x.shape)
x = self.classifier(x)
print(x.shape)
return x
model = Vgg_added_features(model_ft)
print(model)
Ошибка - ValueError: Ожидаемый входной сигнал batch_size (1) будет соответствовать целевому batch_size (12).
Входной дим 224
Вывод для операторов печати:
torch.Size ([12, 3, 224, 224])
torch.Size ([1, 512, 2352])
torch.Size ([1, 262144])
torch.Size ([1, 262144])
torch.Size ([1, 262144])
Размер горелки ([1, 62])