Как построить график модели tf.keras в Tensorflow-2.0? - PullRequest
5 голосов
/ 20 июня 2019

Я обновился до Tensorflow 2.0, а tf.summary.FileWriter("tf_graphs", sess.graph) нет. Я просматривал некоторые другие вопросы StackOverflow по этому вопросу, и они сказали, чтобы использовать tf.compat.v1.summary etc. Конечно, должен быть способ построить график и визуализировать модель tf.keras в Tensorflow версии 2. Что это такое? Я ищу вывод тензорной доски, как показано ниже. Спасибо!

enter image description here

Ответы [ 2 ]

3 голосов
/ 21 июня 2019

Вы можете визуализировать график любой декорированной функции tf.function, но сначала вы должны отследить ее выполнение.

Визуализация графика модели Keras означает визуализацию ее метода call.

По умолчанию этот метод не оформлен tf.function, поэтому вам нужно заключить вызов модели в правильно оформленную функцию и выполнить его.

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
traceme(tf.zeros((1, 28, 28, 1)))
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
2 голосов
/ 20 июня 2019

Согласно документам , вы можете использовать Tensorboard для визуализации графиков после обучения вашей модели.

Сначала определите вашу модель и запустите ее.Затем откройте Tensorboard и переключитесь на вкладку График.


Пример минимальной компиляции

Этот пример взят из документации.Сначала определите вашу модель и данные.

# Relevant imports.
%load_ext tensorboard

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras

# Define the model.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

Затем обучите вашу модель.Здесь вам нужно будет определить обратный вызов для Tensorboard, который будет использоваться для визуализации статистики и графиков.

# Define the Keras TensorBoard callback.
logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Train the model.
model.fit(
    train_images,
    train_labels, 
    batch_size=64,
    epochs=5, 
    callbacks=[tensorboard_callback])

После обучения в своей записной книжке запустите

%tensorboard --logdir logs

и переключитесь на графикВкладка на панели навигации:

enter image description here

Вы увидите график, который выглядит примерно так:

enter image description here

...