В некоторых приложениях, таких как внимание к слотам (реализовано в Pytorch здесь ), необходимо транслировать по пакетному измерению. Однако я не вижу, как это сделать с помощью функционального API. Например,
import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))
const = tf.broadcast_to(const, input.shape)
вызывает следующую ошибку:
ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)
Таким образом, я прибегаю к подклассу tf.keras.Model
, но я хотел бы сохранить свой код в функциональном API. Кто-нибудь знает, как сделать sh это?