Вы можете использовать tf.range()
и tf.meshgrid()
для создания индексных матриц, а затем использовать tf.where()
с вашим условием для получения индексов, которые его удовлетворяют. Однако, сложная часть будет следующей: вы не можете легко присвоить значения тензору на основе индексов в TF (my_tensor[my_indices] = my_values
).
Обходной путь для вашей проблемы («для всех (i,j,k)
, если pob[i,j,k] != 0
, тогда rob[i,j] = 1
») может быть следующим:
import tensorflow as tf
# Example values for demonstration:
pob_val = [[[0, 0, 0], [1, 0, 0], [1, 0, 1]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]]]
pob = tf.constant(pob_val)
pob_shape = tf.shape(pob)
rob = tf.zeros(pob_shape)
# Get the mask:
mask = tf.cast(tf.not_equal(pob, 0), tf.uint8)
# If there's at least one "True" in mask[i, j, :], make all mask[i, j, :] = True:
mask = tf.cast(tf.reduce_max(mask, axis=-1, keepdims=True), tf.bool)
mask = tf.tile(mask, [1, 1, pob_shape[-1]])
# Apply mask:
rob = tf.where(mask, tf.ones(pob_shape), rob)
with tf.Session() as sess:
rob_eval = sess.run(rob)
print(rob_eval)
# [[[0. 0. 0.]
# [1. 1. 1.]
# [1. 1. 1.]]
#
# [[1. 1. 1.]
# [0. 0. 0.]
# [0. 0. 0.]]]