Выберите элемент из другого списка в tenorflow - PullRequest
0 голосов
/ 07 марта 2020

Я работаю над tenorflow и у меня возникает следующая проблема

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

И у меня есть

output = tf.Tensor(
[[[-0.74085818 -1.14085818 -1.59085818]
  [-0.97945321 -1.31945321 -1.02945321]
  [-1.43897936 -0.97897936 -0.94897936]]

 [[-1.27561467 -1.12561467 -0.92561467]
  [-1.38741927 -1.14741927 -0.83741927]
  [-0.88817684 -1.22817684 -1.21817684]]], shape=(2, 3, 3), dtype=float64)

Я хочу выбрать элемент из output по индексам из label. То есть мой окончательный результат должен быть

[[1.59085822 0.97945321 0.97897935]  #2, 0, 1
[1.27561462 0.83741927 1.22817683]], #0, 2, 1
shape=(2, 3), dtype=float64)

1 Ответ

0 голосов
/ 07 марта 2020

Вы не можете сделать это напрямую. Правильный способ добиться этого - сначала нанести одну горячую кодировку на ярлык. Затем используйте tf.boolean_mask для выбора из выходных логов.

Вот пример:

import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
from tensorflow import nn

#2*3
label = np.array([[2, 0, 1], [0, 2, 1]])
#2*3*3
logit = np.array([[[.9, .5, .05], [.35, .01, .3], [.45, .91, .94]], 
         [[.05, .2, .4], [.05, .29, .6], [.35, .01, .02]]])


#find the value corresponding to label index by row
output = nn.log_softmax(logit)

one_hot = tf.one_hot(label, 3, dtype=tf.int32)
# <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
# array([[[0, 0, 1],
#         [1, 0, 0],
#         [0, 1, 0]],
# 
#        [[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]]], dtype=int32)>

result_vec = tf.boolean_mask(output, one_hot) # The result is a vector
# <tf.Tensor: shape=(6,), dtype=float64, numpy=
# array([-1.59085818, -0.97945321, -0.97897936, -1.27561467, -0.83741927,
#        -1.22817684])>

result = tf.reshape(result_vec, label.shape)

В результате вы получите: (вы пропустили отрицательные знаки в своем вопросе?)

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-1.59085818, -0.97945321, -0.97897936],
       [-1.27561467, -0.83741927, -1.22817684]])>

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...