Классификация изображений - pytorch - PullRequest
0 голосов
/ 07 февраля 2020

Я пытаюсь предсказать особенности, используя предварительно обученную модель. Однако я получаю вывод, как я могу использовать torch.max() для получения интересующих классов. Код, который я пробовал:

... loading model
input = transformation_sequence(sample).unsqueeze(0)
outputs = model(input)
_, predicted = torch.max(outputs,1) #this line returns error

#print of `outputs` variable
[tensor([[ 3.0654, -3.0650]]), tensor([[ 1.5634, -1.5672]]), tensor([[ 1.2867, -1.2888]]), tensor([[ 1.2974, -1.2928]]), tensor([[ 6.4537, -6.4487]]), tensor([[ 2.4851, -2.4710]]), tensor([[ 0.9855, -0.9809]]), tensor([[ 0.3995, -0.4033]]), tensor([[ 0.6301, -0.6276]]), tensor([[ 5.7082, -5.6931]]), tensor([[ 1.9354, -1.9365]]), tensor([[ 0.6091, -0.6074]]), tensor([[ 5.4509, -5.4417]]), tensor([[ 3.7231, -3.7115]]), tensor([[ 4.4494, -4.4361]]), tensor([[ 0.8867, -0.8902]]), tensor([[ 2.7410, -2.7402]]), tensor([[ 5.4919, -5.4909]]), tensor([[ 2.2687, -2.2744]]), tensor([[-0.9695,  0.9723]]), tensor([[ 1.5100, -1.5114]]), tensor([[-2.7077,  2.7140]]), tensor([[ 4.4661, -4.4734]]), tensor([[ 0.4846, -0.4821]]), tensor([[-2.9743,  2.9643]]), tensor([[ 1.3900, -1.3874]]), tensor([[ 7.6764, -7.6742]]), tensor([[ 0.5173, -0.5118]]), tensor([[ 1.3513, -1.3503]]), tensor([[ 2.5381, -2.5356]]), tensor([[ 4.9850, -5.0074]]), tensor([[-2.8397,  2.8484]]), tensor([[ 3.1010, -3.1137]]), tensor([[-0.2374,  0.2406]]), tensor([[ 0.5338, -0.5358]]), tensor([[ 3.4912, -3.4979]]), tensor([[ 1.1957, -1.1876]]), tensor([[ 1.1189, -1.1163]]), tensor([[ 3.6400, -3.6365]]), tensor([[-1.3123,  1.3132]])]

#list of error:

  _, predicted = torch.max(outputs,1)
TypeError: max() received an invalid combination of arguments - got (list, int), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, Tensor out)
 * (Tensor input, int dim, bool keepdim, tuple of Tensors out)

1 Ответ

0 голосов
/ 07 февраля 2020

Ваша модель возвращает список тензоров, а не тензоров. Это можно исправить с помощью torch.cat:

torch.max(torch.cat(outputs),1)

>>> torch.return_types.max(
values=tensor([3.0654, 1.5634, 1.2867]),
indices=tensor([0, 0, 0]))
...