Вот функция, которая делает это:
import tensorflow as tf
def sample_indices(valid, m, seed=None):
valid = tf.convert_to_tensor(valid)
n = tf.size(valid)
# Flatten boolean tensor
valid_flat = tf.reshape(valid, [n])
# Get flat indices where the tensor is true
valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
# Shuffled valid indices
valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
# Pick sample from shuffled indices
valid_idx_sample = valid_idx_shuffled[:m]
# Unravel indices
return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))
with tf.Graph().as_default(), tf.Session() as sess:
valid = [[ True, True, False, True],
[False, True, True, False],
[False, True, False, False]]
m = 4
print(sess.run(sample_indices(valid, m, seed=0)))
# [[1 1]
# [1 2]
# [0 1]
# [2 1]]
Этот sample_indices
является общим для любой формы булева тензора.Если в вашем случае valid_entries
имеет форму (T, N, 1)
, вы получите тензор с формой (M, 3)
в качестве вывода, хотя вы можете игнорировать последний столбец, поскольку он всегда будет равен нулю (или вместо него можно передать tf.squeeze(valid_entries, axis=2)
).
Примечание: последний tf.transpose
должен просто выводить тензор с формой (sample_size, num_dimensions)
, а не наоборот.Однако, если m
довольно большой и вы не обращаете внимания на порядок измерений, вы можете пропустить его, чтобы сэкономить немного времени и памяти, поскольку (в отличие от его аналога NumPy) tf.transpose
создает совершенно новый тензор.