На самом деле ваш код эквивалентен следующему коду.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
Метод векторизации будет значительно быстрее, чем tf.map_fn()
. Кроме того, есть некоторый метод векторизации для получения индекса наименьшего элемента в одномерном тензоре, который не равен нулю. Пример:
import tensorflow as tf
# tf.enable_eager_execution()
tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)
# approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)
# approach 2, only the first minimum can be found.
tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]
# approach 3, find all minimum
tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))
with tf.Session() as sess:
print(sess.run(smallest_mask_index1))
print(sess.run(smallest_mask_index2))
print(sess.run(smallest_mask_index3))
# print
2
[2]
[2 4]