Tensorflow: фильтрация дубликатов 3D-индекса по их максимальным значениям - PullRequest
0 голосов
/ 07 ноября 2018

Я пытаюсь создать маску фильтра, которая удаляет повторяющиеся индексы из вектора, сравнивая, какое из их соответствующих значений больше.

Мой текущий подход:

  1. Преобразование 3-D индекса в 1-D
  2. Проверьте 1-D Индекс на уникальность
  3. Рассчитать максимальные значения каждого уникального индекса
  4. Сравните максимальные значения с исходными значениями. Если существует такое же значение, сохраните этот трехмерный индекс.

Я хочу получить массив фильтров, чтобы я мог применить boolean_mask и к другим тензорам. Для этого примера маска должна выглядеть следующим образом: [False True True True True].

Мой текущий код работает, если только сами значения не дублированы. Однако, похоже, это тот случай, когда я его использую, поэтому мне нужно найти лучшее решение.

Вот пример того, как выглядит мой код

import tensorflow as tf

# Dummy Input values with same Structure as the real
x_cells   = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_1
y_cells   = tf.constant([4,4,4,4,4], dtype=tf.int32)   # Index_2
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32) # Index_3
iou_max   = tf.constant([1.,2.,3.,4.,5.], dtype=tf.float32) # Values

# my Output should be a mask that is [False True True True True]
# So if i filter this i get e.g. x_cells = [2,3,4,1] or iou_max = [2.,3.,4.,5.]

max_dim_y = tf.constant(10)
max_dim_x = tf.constant(20)
num_anchors = 5
stride = 32

# 1. Transforming the 3D-Index to 1D
tmp = tf.stack([x_cells, y_cells, iou_index], axis=1)
indices = tf.matmul(tmp, [[max_dim_y * num_anchors],     [num_anchors],[1]])

# 2. Looking for unique / duplicate indices
y, idx = tf.unique(tf.squeeze(indices))

# 3. Calculating the maximum values of each unique index.
# An function like unsorted_segment_argmax() would be awesome here
num_segments = tf.shape(y)[0]
ious = tf.unsorted_segment_max(iou_max, idx, num_segments)

iou_max_length = tf.shape(iou_max)[0]
ious_length = tf.shape(ious)[0]

# 4. Compare all max values to original values.
iou_max_tiled = tf.tile(iou_max, [ious_length])
iou_reshaped = tf.reshape(iou_max_tiled, [ious_length, iou_max_length])
iou_max_reshaped = tf.transpose(iou_reshaped)
filter_mask = tf.reduce_any(tf.equal(iou_max_reshaped, ious), -1)
filter_mask = tf.reshape(filter_mask, shape=[-1])

Этот код выше не будет работать, если мы просто изменим значение переменной iou_max в начале на:

x_cells = tf.constant([1,2,3,4,1], dtype=tf.int32)
y_cells = tf.constant([4,4,4,4,4], dtype=tf.int32)
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)
iou_max = tf.constant([2.,2.,3.,4.,5.], dtype=tf.float32)

1 Ответ

0 голосов
/ 07 ноября 2018

Мой текущий обходной путь изменил пункт 4 моего вопроса:

По сути, я изменил то, что сравниваю кортежи вместо отдельных значений. Это дает мне возможность логически проверять, находятся ли оба, индекс И значение в оставшихся значениях от 3.

# 4. Compare a Max Value and Indices with original values
rem_index_val_pair = tf.stack([ious, tf.cast(y, dtype=tf.float32)], axis=1)
orig_val_index_pair = tf.stack([iou_max, tf.cast(indices, dtype=tf.float32)], axis=1)

orig_val_index_pair_t = tf.tile(orig_val_index_pair, [1, ious_length])
orig_val_index_pair_s = tf.reshape(orig_val_index_pair_t, [iou_max_length, ious_length, 2])
filter_mask_1 = tf.equal(orig_val_index_pair_s, rem_index_val_pair)
filter_mask_2 = tf.reduce_all(filter_mask_1, -1)
filter_mask_3 = tf.reduce_any(filter_mask_2, -1)

# The orig_val_index_pair_s looks like the following
a =  [[[  2.  71.][  2.  71.][  2.  71.][  2.  71.]
     [[  2. 122.][  2. 122.][  2. 122.][  2. 122.]]
     [[  3. 173.][  3. 173.][  3. 173.][  3. 173.]]
     [[  4. 224.][  4. 224.][  4. 224.][  4. 224.]]
     [[  5.  71.][  5.  71.][  5.  71.][  5.  71.]]]
# I then compare it to the rem_max_val_pair which looks like this.
b =  [[  5.  71.][  2. 122.][  3. 173.][  4. 224.]]

# Using equal(a,b) will now compare each of the values resulting in:
c = [[[False  True][ True False][False False][False False]]
     [[False False][ True  True][False False][False False]]
     [[False False][False False][ True  True][False False]]
     [[False False][False False][False False][ True  True]]
     [[ True  True][False False][False False][False False]]]

# Using tf.reduce_all(c, -1) I can filter the bool pairs with a logical And. 
# (This kicks out my false positives from before).
# Afterwards I can check if the line has any true value by tf.reduce_any().

ИМО, это решение все еще является грязным обходным путем. Так что, если у вас есть предложения по лучшим решениям, поделитесь ими. :)

...