Я подклассифицирую tf.keras.model
. Мне нужно переопределить compute_output_shape
, иначе я получу NotImplementedError
от здесь .
class Custom(tf.keras.Model):
...
def compute_output_shape(self, input_shape):
# input_shape = (None, ...)
batch_size = ???
return (batch_size, ...)
compute_output_shape
принимает input_shape
в качестве ввода. Однако это не сильно помогает, поскольку размер пакета каким-то образом теряется в TensorFlow.
Если я попытаюсь вернуть фигуру, которая начинается с None
так же, как input_shape
, я получу TypeError: 'str' object cannot be interpreted as an integer
. Простое пропускание размера пакета также не работает.
Размер пакета является переменным, поэтому я не могу просто жестко его кодировать.