Это одно из возможных решений, хотя оно все еще дорого по времени и памяти, поэтому, вероятно, оно неосуществимо для большого случая использования:
import tensorflow as tf
def sparse_select_indices(sp_input, indices, axis=0):
# Only necessary if indices may have non-unique elements
indices, _ = tf.unique(indices)
n_indices = tf.size(indices)
# Only necessary if indices may not be sorted
indices, _ = tf.math.top_k(indices, n_indices)
indices = tf.reverse(indices, [0])
# Get indices for the axis
idx = sp_input.indices[:, axis]
# Find where indices match the selection
eq = tf.equal(tf.expand_dims(idx, 1), tf.cast(indices, tf.int64))
# Mask for selected values
sel = tf.reduce_any(eq, axis=1)
# Selected values
values_new = tf.boolean_mask(sp_input.values, sel, axis=0)
# New index value for selected elements
n_indices = tf.cast(n_indices, tf.int64)
idx_new = tf.reduce_sum(tf.cast(eq, tf.int64) * tf.range(n_indices), axis=1)
idx_new = tf.boolean_mask(idx_new, sel, axis=0)
# New full indices tensor
indices_new = tf.boolean_mask(sp_input.indices, sel, axis=0)
indices_new = tf.concat([indices_new[:, :axis],
tf.expand_dims(idx_new, 1),
indices_new[:, axis + 1:]], axis=1)
# New shape
shape_new = tf.concat([sp_input.dense_shape[:axis],
[n_indices],
sp_input.dense_shape[axis + 1:]], axis=0)
return tf.SparseTensor(indices_new, values_new, shape_new)
Вот пример использования:
import tensorflow as tf
with tf.Session() as sess:
# Input
sp1 = tf.SparseTensor([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [6, 7])
print(sess.run(tf.sparse.to_dense(sp1)))
# [[ 0 10 0 0 0 0 0]
# [ 0 0 0 0 0 0 0]
# [ 0 0 0 20 0 0 0]
# [ 0 0 0 0 0 0 0]
# [ 0 0 0 0 0 30 0]
# [ 0 0 0 0 0 0 0]]
# Select rows 0, 1, 2
sp2 = sparse_select_indices(sp1, [0, 1, 2])
print(sess.run(tf.sparse.to_dense(sp2)))
# [[ 0 10 0 0 0 0 0]
# [ 0 0 0 0 0 0 0]
# [ 0 0 0 20 0 0 0]]
# Select columns 4, 5
sp3 = sparse_select_indices(sp1, [4, 5], axis=1)
print(sess.run(tf.sparse.to_dense(sp3)))
# [[ 0 0]
# [ 0 0]
# [ 0 0]
# [ 0 0]
# [ 0 30]
# [ 0 0]]