Как получить размер партии в пользовательской модели Keras TensorFlow (2.0)? - PullRequest
0 голосов
/ 27 апреля 2019

Я подклассифицирую tf.keras.model. Мне нужно переопределить compute_output_shape, иначе я получу NotImplementedError от здесь .

class Custom(tf.keras.Model):
    ...

    def compute_output_shape(self, input_shape):
        # input_shape = (None, ...)
        batch_size = ???
        return (batch_size, ...)

compute_output_shape принимает input_shape в качестве ввода. Однако это не сильно помогает, поскольку размер пакета каким-то образом теряется в TensorFlow.

Если я попытаюсь вернуть фигуру, которая начинается с None так же, как input_shape, я получу TypeError: 'str' object cannot be interpreted as an integer. Простое пропускание размера пакета также не работает.

Размер пакета является переменным, поэтому я не могу просто жестко его кодировать.

...