Tensorflow: Как случайным образом выбрать элементы в соответствии с условием без np.where? - PullRequest
1 голос
/ 02 мая 2019

У меня есть 3 массива тензорных потоков (a, b, valid_entries), которые разделяют первые два измерения [T, N, ?].Один из этих массивов 'valid_entries' имеет форму [T,N,1] с логическими значениями.Я хочу случайным образом выбрать T*M 2 набора индексов (M < N), чтобы valid_entries[t,m] == 1 для всех этих индексов.

Другими словами, для каждого временного шага я хочу случайным образом выбрать Mдопустимые записи из a и b.

Я предполагаю, что в numpy эта задача будет решена следующим образом (для простоты пропустим первое измерение T):

M = 3
N = 5
valid_entries = [[0],[1],[0],[1],[0]]
valid_indices = np.where(a==1)
valid_indices = np.random.select(valid_indices,np.min(len(valid_indices),M))
a_new = a[valid_indices]
b_new = b[valid_indices]
valid_new = valid_entries[valid_indices]

Однако все это должно произойти в Tensorflow.

Заранее огромное спасибо за любую помощь!

1 Ответ

1 голос
/ 02 мая 2019

Вот функция, которая делает это:

import tensorflow as tf

def sample_indices(valid, m, seed=None):
    valid = tf.convert_to_tensor(valid)
    n = tf.size(valid)
    # Flatten boolean tensor
    valid_flat = tf.reshape(valid, [n])
    # Get flat indices where the tensor is true
    valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
    # Shuffled valid indices
    valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
    # Pick sample from shuffled indices
    valid_idx_sample = valid_idx_shuffled[:m]
    # Unravel indices
    return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))

with tf.Graph().as_default(), tf.Session() as sess:
    valid = [[ True,  True, False,  True],
             [False,  True,  True, False],
             [False,  True, False, False]]
    m = 4
    print(sess.run(sample_indices(valid, m, seed=0)))
    # [[1 1]
    #  [1 2]
    #  [0 1]
    #  [2 1]]

Этот sample_indices является общим для любой формы булева тензора.Если в вашем случае valid_entries имеет форму (T, N, 1), вы получите тензор с формой (M, 3) в качестве вывода, хотя вы можете игнорировать последний столбец, поскольку он всегда будет равен нулю (или вместо него можно передать tf.squeeze(valid_entries, axis=2)).

Примечание: последний tf.transpose должен просто выводить тензор с формой (sample_size, num_dimensions), а не наоборот.Однако, если m довольно большой и вы не обращаете внимания на порядок измерений, вы можете пропустить его, чтобы сэкономить немного времени и памяти, поскольку (в отличие от его аналога NumPy) tf.transpose создает совершенно новый тензор.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...