TensorFlow: эффективный способ получить индекс наименьшего элемента в тензоре, который не равен нулю? - PullRequest
2 голосов
/ 19 мая 2019

Я использую TensorFlow 1.12.У меня есть одномерный тензор tag_mask_sizes, который в основном содержит нули, а также несколько натуральных чисел.Как я могу эффективно получить индекс наименьшего элемента, который не равен нулю?Я попробовал следующее:

tag_mask_sizes_suppressed = tf.map_fn(lambda x: x if tf.not_equal(x, tf.constant(0, dtype=tf.uint8)) else 9999999, tag_mask_sizes)
        smallest_mask_index = tf.argmin(tag_mask_sizes_suppressed)

Однако tf.not_equal() дает булевский тензор, который я не могу эффективно оценить в условии if-else внутри лямбды.Существуют ли другие изящные решения, подобные этому?

Хотя я обычно выполняю с нетерпением, эта проблема возникает внутри функции, которую я использую в tf.Dataset.map(), которая не выполняется с нетерпением.

1 Ответ

1 голос
/ 19 мая 2019

На самом деле ваш код эквивалентен следующему коду.

tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

Метод векторизации будет значительно быстрее, чем tf.map_fn(). Кроме того, есть некоторый метод векторизации для получения индекса наименьшего элемента в одномерном тензоре, который не равен нулю. Пример:

import tensorflow as tf
# tf.enable_eager_execution()

tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)

# approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

# approach 2, only the first minimum can be found.
tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]

# approach 3, find all minimum
tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))

with tf.Session() as sess:
    print(sess.run(smallest_mask_index1))
    print(sess.run(smallest_mask_index2))
    print(sess.run(smallest_mask_index3))

# print
2
[2]
[2 4]
...