Есть ли эквивалентная функция pytorch с именем "index_select" в тензорном потоке - PullRequest
1 голос
/ 19 октября 2019

Я пытался перевести код Pytorch в тензорный поток. Итак, я хочу знать, есть ли эквивалентная функция pytorch с именем "index_select" в тензор потока

1 Ответ

1 голос
/ 19 октября 2019

Я не обнаружил, что подобные API-интерфейсы могут непосредственно достичь этого, но мы можем использовать tf.slice для его реализации.

def tf_index_select(input_, dim, indices):
    """
    input_(tensor): input tensor
    dim(int): dimension
    indices(list): selected indices list
    """
    shape = input_.get_shape().as_list()
    if dim == -1:
        dim = len(shape)-1
    shape[dim] = 1

    tmp = []
    for idx in indices:
        begin = [0]*len(shape)
        begin[dim] = idx
        tmp.append(tf.slice(input_, begin, shape))
    res = tf.concat(tmp, axis=dim)

    return res

Вот пример, чтобы показать эквивалентность.

import tensorflow as tf
import torch
import numpy as np

a = np.arange(2*3*4).reshape(2,3,4)
dim = 1
indices = [0,2]
# array([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

# pytorch
res = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])

# tensorflow
res = tf_index_select(tf.constant(a), dim, indices)
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...