Выбор индекса в случае конфликта в Pytorch Argmax - PullRequest
2 голосов
/ 13 марта 2019

Я пытался выучить тензорные операции, и это бросило меня в тупик.
Допустим, у меня есть один тензор t:

    t = torch.tensor([
        [1,0,0,2],
        [0,3,3,0],
        [4,0,0,5]
    ], dtype  = torch.float32)

Теперь это тензор ранга 2, и мы можем применить argmax для каждого ранга / измерения. скажем, мы применяем его для дим = 1

t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))

Теперь мы можем видеть, что в результате, как и ожидалось, тензор вдоль dim = 1 имеет 2,3 и 5 в качестве элементов max. Но есть конфликт 3. Есть два значения, которые в точности совпадают.
Как это решается? это произвольно выбран? Есть ли порядок выбора, например, L-R, более высокое значение индекса?
Буду признателен за понимание того, как это решается!

1 Ответ

3 голосов
/ 13 марта 2019

Это хороший вопрос, на который я сам споткнулся пару раз.Самый простой ответ заключается в том, что нет никаких гарантий того, что torch.argmax (или torch.max(x, dim=k), который также возвращает индексы, если указан dim) будет последовательно возвращать один и тот же индекс.Вместо этого он вернет любой допустимый индекс к значению argmax, возможно, в случайном порядке.Как обсуждает эта тема на официальном форуме , это считается желаемым поведением.(Я знаю, что есть еще одна ветка, которую я читал некоторое время назад, которая делает это более явным, но я не могу найти его снова).

Сказав, что это поведение неприемлемо для моего сценария использования, я написал следующеефункции, которые найдут левый и правый индексы (помните, что condition - это объект-функция, который вы передаете):

def __consistent_args(input, condition, indices):
    assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
    mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
    return torch.argmax(mask, dim=1)


def consistent_find_leftmost(input, condition):
    indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)


def consistent_find_rightmost(input, condition):
    indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)

# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)                                                                                                                                     
# will return: 
# tensor([6])

Надеюсь, они помогут!(О, и, пожалуйста, дайте мне знать, если у вас есть лучшая реализация, которая делает то же самое)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...