Реализуем torch.gather в тензорном потоке - PullRequest
0 голосов
/ 06 марта 2020

Я пытаюсь реализовать torch.gather в тензорном потоке. У меня есть макет кода с использованием массивов numpy, но я не могу понять, как реализовать для тензорного потока. Его нельзя использовать, как написано ниже, потому что, когда я получу фактический код тензорного потока с помощью тензоров 'sr c' и 'индексы', будет иметь неизвестное измерение пакета.

Я бы использовал map_fn, но я нужно фактическое расположение индекса, чтобы выяснить, какое значение вывести в конечный тензор.

Помощь.

Ссылки: Что делает функция сбора в pytorch в терминах непрофессионала?

def gather(src, indices, dim)
    output = np.zeros(index.shape, dtype = src.dtype)
    ranges = [range(i) for i in index.shape]
    for idx in itertools.product(*ranges):
        gather_idx = list(idx)
        gather_idx[dim] = indices[idx]
        output[idx] = src[gather_idx]
...