Выбор строк или элементов на разреженном тензоре - PullRequest
0 голосов
/ 29 ноября 2018

В тензорном потоке, как мы можем сделать tf.gather или tf.gather_nd в разреженном тензоре?Как мы можем извлечь отдельные строки или определенные элементы из разреженного тензора, не превращая его в плотный тензор?

1 Ответ

0 голосов
/ 04 декабря 2018

Это одно из возможных решений, хотя оно все еще дорого по времени и памяти, поэтому, вероятно, оно неосуществимо для большого случая использования:

import tensorflow as tf

def sparse_select_indices(sp_input, indices, axis=0):
    # Only necessary if indices may have non-unique elements
    indices, _ = tf.unique(indices)
    n_indices = tf.size(indices)
    # Only necessary if indices may not be sorted
    indices, _ = tf.math.top_k(indices, n_indices)
    indices = tf.reverse(indices, [0])
    # Get indices for the axis
    idx = sp_input.indices[:, axis]
    # Find where indices match the selection
    eq = tf.equal(tf.expand_dims(idx, 1), tf.cast(indices, tf.int64))
    # Mask for selected values
    sel = tf.reduce_any(eq, axis=1)
    # Selected values
    values_new = tf.boolean_mask(sp_input.values, sel, axis=0)
    # New index value for selected elements
    n_indices = tf.cast(n_indices, tf.int64)
    idx_new = tf.reduce_sum(tf.cast(eq, tf.int64) * tf.range(n_indices), axis=1)
    idx_new = tf.boolean_mask(idx_new, sel, axis=0)
    # New full indices tensor
    indices_new = tf.boolean_mask(sp_input.indices, sel, axis=0)
    indices_new = tf.concat([indices_new[:, :axis],
                             tf.expand_dims(idx_new, 1),
                             indices_new[:, axis + 1:]], axis=1)
    # New shape
    shape_new = tf.concat([sp_input.dense_shape[:axis],
                           [n_indices],
                           sp_input.dense_shape[axis + 1:]], axis=0)
    return tf.SparseTensor(indices_new, values_new, shape_new)

Вот пример использования:

import tensorflow as tf

with tf.Session() as sess:
    # Input
    sp1 = tf.SparseTensor([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [6, 7])
    print(sess.run(tf.sparse.to_dense(sp1)))
    # [[ 0 10  0  0  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0 20  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0  0  0 30  0]
    #  [ 0  0  0  0  0  0  0]]

    # Select rows 0, 1, 2
    sp2 = sparse_select_indices(sp1, [0, 1, 2])
    print(sess.run(tf.sparse.to_dense(sp2)))
    # [[ 0 10  0  0  0  0  0]
    #  [ 0  0  0  0  0  0  0]
    #  [ 0  0  0 20  0  0  0]]

    # Select columns 4, 5
    sp3 = sparse_select_indices(sp1, [4, 5], axis=1)
    print(sess.run(tf.sparse.to_dense(sp3)))
    # [[ 0  0]
    #  [ 0  0]
    #  [ 0  0]
    #  [ 0  0]
    #  [ 0 30]
    #  [ 0  0]]
...