TL; DR: Как я могу разбить 2D бинарный тензор из 2 меток на экземпляр, на 2 тензора с только 1 меткой на экземпляр, как на этом рисунке:
![enter image description here](https://i.stack.imgur.com/6iB8n.png)
Как часть пользовательской функции потерь, я пытаюсь разделить y-тензор с несколькими метками, с 2 метками на экземпляр, на 2 тензора y с одной меткой на экземпляр. Когда я делаю это на 1D y тензор, этот код прекрасно работает:
y_true = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 0.])
label_cls = tf.where(tf.equal(y_true, 1.))
idx1, idx2 = tf.split(label_cls,2)
raplace = tf.constant([1.])
y_true_1 = tf.scatter_nd(tf.cast(idx1, dtype=tf.int32), raplace, [tf.size(y_true)])
y_true_2 = tf.scatter_nd(tf.cast(idx2, dtype=tf.int32), raplace, [tf.size(y_true)])
with tf.Session() as sess:
print(sess.run([y_true_1,y_true_2]))
И я получаю:
[array([1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)]
Но когда я использую партии в обучении, я получаю эту ошибку:
Invalid argument: Outer dimensions of indices and update must match.
Поскольку мои "тензоры у" являются двумерными, а не одномерными, и в этом случае - idx1, idx2
(индексы) не верны, равно как и форма replace
(обновления),Насколько я понимаю, tf.scatter_nd
может обновлять только первое измерение переменной, так как я могу обойти это? и как я могу получить необходимые индексы для этого?