Как я могу использовать Tensorflows scatter_nd во втором измерении 2D-тензора? - PullRequest
0 голосов
/ 05 октября 2019

TL; DR: Как я могу разбить 2D бинарный тензор из 2 меток на экземпляр, на 2 тензора с только 1 меткой на экземпляр, как на этом рисунке:

enter image description here

Как часть пользовательской функции потерь, я пытаюсь разделить y-тензор с несколькими метками, с 2 метками на экземпляр, на 2 тензора y с одной меткой на экземпляр. Когда я делаю это на 1D y тензор, этот код прекрасно работает:

y_true = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 0.])
label_cls = tf.where(tf.equal(y_true, 1.))
idx1, idx2 = tf.split(label_cls,2)
raplace = tf.constant([1.])
y_true_1 = tf.scatter_nd(tf.cast(idx1, dtype=tf.int32), raplace, [tf.size(y_true)]) 
y_true_2 = tf.scatter_nd(tf.cast(idx2, dtype=tf.int32), raplace, [tf.size(y_true)])  

with tf.Session() as sess:
    print(sess.run([y_true_1,y_true_2]))

И я получаю:

[array([1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)]

Но когда я использую партии в обучении, я получаю эту ошибку:

Invalid argument: Outer dimensions of indices and update must match.

Поскольку мои "тензоры у" являются двумерными, а не одномерными, и в этом случае - idx1, idx2 (индексы) не верны, равно как и форма replace (обновления),Насколько я понимаю, tf.scatter_nd может обновлять только первое измерение переменной, так как я могу обойти это? и как я могу получить необходимые индексы для этого?

1 Ответ

0 голосов
/ 06 октября 2019

Я чувствую, что вы идете по запутанному пути. Вот мое решение. Почувствуйте, что это более просто, чем тот, с которым вы пытаетесь пойти (Попытка с тф 1.14).

import tensorflow as tf

y_true = tf.constant([[1, 0, 1, 0],[0, 1, 1, 0]])
_, label_inds = tf.math.top_k(y_true, k=2)
idx1, idx2 = tf.split(label_inds,2, axis=1)

y_true_1 = tf.one_hot(idx1, depth=4)
y_true_2 = tf.one_hot(idx2, depth=4)

with tf.Session() as sess:

    print(sess.run([y_true_1, y_true_2]))

Итак, идея в том, что вы получаете индексы верхних 2 меток для каждой строки. Затем разделите это на 2 столбца, используя tf.split. А затем используйте one_hot для преобразования этих индексов обратно в одноручные векторы.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...