Среднее объединение с окном по последовательностям переменной длины - PullRequest
0 голосов
/ 23 марта 2020

У меня есть тензор in формы (batch_size, feature, steps) и я хочу получить выходной тензор out той же формы путем среднего объединения по измерению по времени (шагам) с размером окна 2k+1, то есть:

out[b,f,t] = 1/(2k+1) sum_{t'=t-k,...,t+k} in[b,f,t']

Для временных шагов, в которых нет k предшествующих и последующих временных шагов, я хочу рассчитать среднее значение только для существующих временных шагов.

Однако последовательности в тензоре имеют переменную длину и дополняются нулями соответственно, длины последовательностей хранятся в другом тензоре (и я могу, например, создать маску с ними). ​​

  • Я знаю, что могу использовать out = tf.nn.avg_pool1d(in, ksize=2k+1, strides=1, padding="SAME", data_format="NCW"), который выполняет мою описанную операцию объединения, однако он не понимает, что мои последовательности заполнены нулями, и не позволяет мне передать маску с длинами последовательностей.
  • Там также tf.keras.layers.GlobalAveragePooling1D, но этот слой всегда объединяет всю последовательность и не позволяет мне указать размер окна.

Как выполнить такое операция с маскировкой и размером окна ?

1 Ответ

0 голосов
/ 04 мая 2020

Насколько я знаю, в TensorFlow такой операции нет. Однако можно использовать комбинацию двух немаскированных операций пула, здесь записанных в псевдокоде:

  1. Пусть seq_mask будет маской последовательности формы (batch_size, time)
  2. Пусть in_pooled - тензор in с немаскированным средним пулом
  3. Пусть seq_mask_pooled - тензор seq_mask с немаскированным средним пулом с тем же размером пула
  4. Получите тензор out следующим образом: Каждый элемент out, где маска последовательности равна 0, также должен быть 0. Любой другой элемент получается путем деления элементов от in_pooled до seq_mask_pooled (не то, что элемент seq_mask_pooled никогда не равен 0, если элемент seq_mask не равен).

Тензор out может быть, например, рассчитан с использованием tf.math.divide_no_nan.

...