Это хороший вопрос, на который я сам споткнулся пару раз.Самый простой ответ заключается в том, что нет никаких гарантий того, что torch.argmax
(или torch.max(x, dim=k)
, который также возвращает индексы, если указан dim) будет последовательно возвращать один и тот же индекс.Вместо этого он вернет любой допустимый индекс к значению argmax, возможно, в случайном порядке.Как обсуждает эта тема на официальном форуме , это считается желаемым поведением.(Я знаю, что есть еще одна ветка, которую я читал некоторое время назад, которая делает это более явным, но я не могу найти его снова).
Сказав, что это поведение неприемлемо для моего сценария использования, я написал следующеефункции, которые найдут левый и правый индексы (помните, что condition
- это объект-функция, который вы передаете):
def __consistent_args(input, condition, indices):
assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
return torch.argmax(mask, dim=1)
def consistent_find_leftmost(input, condition):
indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
def consistent_find_rightmost(input, condition):
indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)
# will return:
# tensor([6])
Надеюсь, они помогут!(О, и, пожалуйста, дайте мне знать, если у вас есть лучшая реализация, которая делает то же самое)