Я не обнаружил, что подобные API-интерфейсы могут непосредственно достичь этого, но мы можем использовать tf.slice
для его реализации.
def tf_index_select(input_, dim, indices):
"""
input_(tensor): input tensor
dim(int): dimension
indices(list): selected indices list
"""
shape = input_.get_shape().as_list()
if dim == -1:
dim = len(shape)-1
shape[dim] = 1
tmp = []
for idx in indices:
begin = [0]*len(shape)
begin[dim] = idx
tmp.append(tf.slice(input_, begin, shape))
res = tf.concat(tmp, axis=dim)
return res
Вот пример, чтобы показать эквивалентность.
import tensorflow as tf
import torch
import numpy as np
a = np.arange(2*3*4).reshape(2,3,4)
dim = 1
indices = [0,2]
# array([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
# pytorch
res = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [20, 21, 22, 23]]])
# tensorflow
res = tf_index_select(tf.constant(a), dim, indices)
# tensor([[[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [20, 21, 22, 23]]])