Как транслировать пакетное измерение с помощью функционального API Tensorflow? - PullRequest
0 голосов
/ 02 августа 2020

В некоторых приложениях, таких как внимание к слотам (реализовано в Pytorch здесь ), необходимо транслировать по пакетному измерению. Однако я не вижу, как это сделать с помощью функционального API. Например,

import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, input.shape)

вызывает следующую ошибку:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)

Таким образом, я прибегаю к подклассу tf.keras.Model, но я хотел бы сохранить свой код в функциональном API. Кто-нибудь знает, как сделать sh это?

1 Ответ

0 голосов
/ 22 августа 2020

Наконец нашел ответ на этот вопрос, используя tf.keras.backend.shape:

const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )

# Shape of const is now (None, 4)
...