Вы можете использовать функцию torch.max()
. Таким образом, вы можете сделать что-то вроде
x = torch.Tensor([[-5, 0, -1],
[3, 100, 87],
[17, -34, 2],
[45, 1, 25]])
out, inds = torch.max(x,dim=1)
, и это вернет максимальные значения в каждой строке (измерение 1). Он вернет максимальные значения с их индексами.