Использование torch.argmax()
(для PyTorch +0,4):
prediction = torch.argmax(tensor, dim=1) # with 'dim' the considered dimension
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)
Если версия PyTorch ниже 0.4.0, можно использовать tensor.max()
, которая возвращает как максимальные значения, так и их индексы (но которые не дифференцируются по значениям индекса):
_, prediction = tensor.max(dim=1)
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)