Я использую vgg16, где количество классов равно 3, и для точки данных можно прогнозировать несколько меток.
vgg16 = models.vgg16(pretrained=True)
vgg16.classifier[6]= nn.Linear(4096, 3)
с использованием функции потерь: nn.BCEWithLogitsLoss()
Я могу найти точность поиска в случае проблемы с одной меткой, как
`images, labels = data
images, labels = images.to(device), labels.to(device)
labels = Encode(labels)
outputs = vgg16(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = (100 * correct / total)`
Как найти точность для классификации по нескольким меткам?