Как выполнить TensoFlow 2 Keras Sequential в нетерпеливом режиме при «компиляции»? - PullRequest
0 голосов
/ 06 ноября 2019

Я хочу построить убыток "питоническим" способом, используя стремительное выполнение TF2, но даже в нетерпеливом режиме Keras передает непростые тензоры.

Код:

    def conditional_loss(self, y_true, y_pred):
        print(y_true)
        return 0

    def define_model(self):
        self.model = keras.Sequential([
            keras.layers.Dense(units=768),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Dropout(0.2),
            keras.layers.Dense(units=128),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Dropout(0.2),
            keras.layers.Dense(units=5, activation='softmax')
        ])

        self.model.compile(optimizer='adam',
                           loss=self.conditional_loss,
                           metrics=[self.conditional_loss, 
                                    keras.metrics.sparse_categorical_accuracy]
                           )
        self.model.fit(
            self.train_dataset,
            epochs=10,
            validation_data=self.test_dataset,
            callbacks=[tensorboard_callback, model_callback],
        )

Если я печатаю y_true в conditional_loss TF печатает не нетерпеливый тензор.

Tensor("metrics/conditional_loss/Cast:0", shape=(None, 1), dtype=float32)

Если я строю свой собственный keras.Model(), я могу вызвать его с аргументом dynamic=True, чтобы включитьнетерпеливое исполнение. ( Ссылки ). Существует способ сделать это в keras.Sequential()?

1 Ответ

0 голосов
/ 06 ноября 2019

Для этого вам нужно вызвать model.compile() с аргументом run_eagerly=True. Следующий пример вопроса:

self.model.compile(optimizer='adam',
                           loss=self.conditional_loss,
                           metrics=[self.conditional_loss, 
                                    keras.metrics.sparse_categorical_accuracy],
                           run_eagerly=True
                           )
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...