Я работаю над argmax
функцией PyTorch, которая определяется как:
torch.argmax(input, dim=None, keepdim=False)
Рассмотрим пример
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
Здесь, когда я использую dim = 1 вместо столбца поискавекторы, функция ищет векторы строк, как показано ниже.
print(a) :
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
Насколько я понимаю, dim = 0 представляет строки, а dim = 1 представляет столбцы.