Вы можете сделать это без цикла, используя tf.scatter_nd
следующим образом:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Broadcast columns
columns_bc = tf.broadcast_to(columns_tf, tf.shape(rows_tf))
# Scatter values to indices
scatter_idx = tf.stack([rows_tf, columns_bc], axis=-1)
mask = tf.scatter_nd(scatter_idx, tf.ones_like(rows_tf, dtype=tf.bool), tf.shape(out))
print(sess.run(mask))
Выход:
[[False False False False]
[False True True True]
[False True True True]
[False False True True]
[False False True True]
[False True True True]
[False False True True]
[False False True False]
[False False False False]
[False False False False]]
В качестве альтернативы, вы также можете сделать этоиспользуя только логические операции:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Compare indices
row_eq = tf.equal(tf.range(out.shape[0])[:, tf.newaxis],
rows_tf[..., np.newaxis, np.newaxis])
col_eq = tf.equal(tf.range(out.shape[1])[tf.newaxis, :],
columns_tf[..., np.newaxis, np.newaxis])
# Aggregate
mask = tf.reduce_any(row_eq & col_eq, axis=[0, 1])
print(sess.run(mask))
# Same as before
Однако в принципе это потребует больше памяти.