Как индексировать тензор с формой (batch_size, 200, 256), чтобы получить (batch_size, 1, 256) заданный список тензора индекса с length = batch_size? - PullRequest
1 голос
/ 12 октября 2019

У меня есть выход из слоя LSTM с формой (batch_size, 200, 256), где 200 - длина последовательности токенов, а 256 - выходной размер LSTM. У меня также есть другой тензор с формой (batch_size), который является списком индекса токена, который я хочу вырезать из каждой последовательности образцов в пакете.

Если индекс токена не равен -1, я выделю представление вектора токена (длина = 256). Если индекс токена равен -1, я выдам нулевой вектор (длина = 256).

Ожидаемый результат вывода имеет форму (batch_size, 1, 256). Как мне это сделать?

Спасибо

Вот что я пробовал до сих пор

bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
    matrix_to_slice = x[0]
    index_tensor = x[1]


    out_tensor = tf.where(index_tensor == -1, 
                          tf.zeros(tf.shape(tf.gather(matrix_to_slice, 
                                                      index_tensor, axis=1))), 
                          tf.gather(matrix_to_slice, index_tensor, axis=1))



    return out_tensor


representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) 
# stack_idx0 shape is (batch_size) 
# I got output with shape (batch_size, batch_size, 256) with this code

1 Ответ

0 голосов
/ 12 октября 2019
a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
#     [[[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#      [[12, 13, 14, 15],
#      [16, 17, 18, 19],
#       [20, 21, 22, 23]]]

b=tf.constant([-1,2]) 

aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 

bb=b+1 

index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0,  0,  0,  0]],
#[[20, 21, 22, 23]]]

Когда индекс равен -1, нам нужны нули, такие как тензор. Таким образом, мы можем сначала добавить оригинальный тензор вдоль второй оси. Затем увеличьте индексы на 1. После этого, используя tf.gather_nd, вы получите ответ.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...