Следите за выходными данными и средствами весового слоя во время обучения model.fit в Tensorflow Keras. Отображать метрику c из train_step, а не только в конце каждой эпохи - PullRequest
1 голос
/ 20 июня 2020

Большинство методов, которые я вижу, используют обратный вызов для отображения статистики в конце каждой эпохи. Я хочу показывать постоянно обновляемую статистику с каждого шага вместе с остальными показателями.

Это то, что я пробовал до сих пор.

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        returnDict = {m.name: m.result() for m in self.metrics}

        returnDict['robertaNorm'] = tf.norm(self.get_layer(index=2).get_weights()[-1])
        returnDict['qtNorm'] = tf.norm(self.get_layer(index=3).get_weights()[-1])
        returnDict['ptNorm'] = tf.norm(self.get_layer(index=4).get_weights()[-1])

        returnDict['qFF'] = tf.norm(self.get_layer(index=25).get_weights()[0])
        returnDict['pFF'] = tf.norm(self.get_layer(index=26).get_weights()[0])
        returnDict['qFF2'] = tf.norm(self.get_layer(index=27).get_weights()[0])
        returnDict['pFF2'] = tf.norm(self.get_layer(index=28).get_weights()[0])

        # tf.norm(model.get_layer('qffPost_2').get_weights()[0])

        return returnDict

Однако я получаю сообщения об ошибках, в которых говорится, что к ним нельзя получить доступ на графике.

Я тоже пробовал

returnDict['qFF2'] = tf.norm(self.layers[27].get_weights()[0])

И получил те же результаты.

...