Как нарезать тензор с учетом стартовых индексов для каждой строки в TF2.1? - PullRequest
0 голосов
/ 13 апреля 2020

С учетом некоторого (как минимум 2-мерного) ввода, например:

inputs = [['a0', 'a1', 'a2', 'a3', 'a4'],
          ['b0', 'b1', 'b2', 'b3', 'b4'],
          ['c0', 'c1', 'c2', 'c3', 'c4']]

... и другого ввода индексов и размера скалярного окна:

indices = [2, 3, 0]  # representing the starting positions (2nd dimension)
window_size = 2      # fixed-width of each window

Как можно Я получаю windows, начиная с этих индексов в Tensorflow 2? Сначала я расскажу об использовании регулярных срезов, таких как inputs[,start:start+window_size], но это неприменимо, поскольку это позволяет использовать только один начальный индекс для всех строк и не поддерживает варьирование индексов на строку.


Ожидаемый результат для этого образца будет:

output = [['a2', 'a3'],
          ['b3', 'b4'],
          ['c0', 'c1']]

1 Ответ

1 голос
/ 14 апреля 2020

Я предоставляю метод векторизации. Метод векторизации будет значительно быстрее, чем tf.map_fn().

import tensorflow as tf

inputs = tf.constant([['a0', 'a1', 'a2', 'a3', 'a4'],
                      ['b0', 'b1', 'b2', 'b3', 'b4'],
                      ['c0', 'c1', 'c2', 'c3', 'c4']])
indicies = tf.constant([2, 3, 0])
window_size = 2

start_index = tf.sequence_mask(indicies,inputs.shape[-1])
# tf.Tensor(
# [[ True  True False False False]
#  [ True  True  True False False]
#  [False False False False False]], shape=(3, 5), dtype=bool)
end_index = tf.sequence_mask(indicies+window_size,inputs.shape[-1])
# tf.Tensor(
# [[ True  True  True  True False]
#  [ True  True  True  True  True]
#  [ True  True False False False]], shape=(3, 5), dtype=bool)

index = tf.not_equal(start_index,end_index)
# tf.Tensor(
# [[False False  True  True False]
#  [False False False  True  True]
#  [ True  True False False False]], shape=(3, 5), dtype=bool)

result = tf.reshape(tf.boolean_mask(inputs,index),
                    indicies.get_shape().as_list()+[window_size])
print(result)
# tf.Tensor(
# [[b'a2' b'a3']
#  [b'b3' b'b4']
#  [b'c0' b'c1']], shape=(3, 2), dtype=string)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...