У меня есть идентификаторы входного тензора input_ids
с формой: [B x T]
и соответствующая матрица внедрения с формой [B x T x D]
(B: Batch size, T: Sequence Length, D: Dimension)
.Входные идентификаторы - это словарные идентификаторы, а матрица внедрения содержит соответствующие вложения.
Из матрицы внедрения я хочу выбрать эти элементы с определенными идентификаторами (например, 103
).Было бы легко сделать это, используя tf.where
и tf.gather_nd
, но я не знаю, как это сделать, - организовать результаты в пакете размером [B x N x D]
, где N
- максимальное количество токенов сэтот идентификатор (103
) в последовательности.Я хочу использовать 0 тензоров для заполнения по мере необходимости.
Код может показать это лучше (скажем, B=2, T=8, and D=3
):
import tensorflow as tf
tf.enable_eager_execution()
input_ids = tf.constant([[ 101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
[ 101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random_normal((2,8,3))
# input ids have two sequences. first one has one 103 element, while second has 3.
Я хочу выбрать из embeddings
те, которыесоответствуют input_ids==103
и дополняют оставшиеся результаты нулями.Я могу получить это с помощью:
indices= tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
#result.shape==[4x3]
# This will result in a [4x3] matrix where 4 = total number of 103 elements in the batch
# and 3 is their corresponding embeddings dimension
# Now I want to organize this into a batch of the
# same batch size as input, i.e., desired shape=(2x3)
# where first (1x3) row contains all token `103`'s embeddings
# in the first sequence but but second (1x3) row has only
# one token 103 embedding (second sequence has only one 103 token)
# the rest are padded with zeros.
Как правило, это приведет к тензору [M x D]
(M = общее количество 103 токенов в пакете).То, что я хочу, это [B x N x D]
, где (N = максимальное количество 103 токенов в каждой последовательности, для вышеприведенного случая это 3).Надеюсь, описание понятно (вроде сложно объяснить точную проблему).
Как мне этого добиться?