Pytorch предоставляет torch.topk(input, k, dim=None, largest=True, sorted=True)
функцию для вычисления k
наибольших элементов заданного input
тензора вдоль заданного измерения dim
.
У меня есть тензор формы (16, 512, 4096)
, и я использую torch.topk
следующим образом -
# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)
Я нашел аналогичную реализацию тензорного потока следующим образом - tf.nn.top_k(input, k=1, sorted=True, name=None)
.
Мой вопрос заключается в том, как включить параметр dim=2
в tf.nn.top_k
, чтобы получить тензор той же формы, который рассчитывается с помощью pytorch?