Как построить модель API подклассов Keras / Tensorflow? - PullRequest
2 голосов
/ 25 апреля 2020

Я создал модель, которая работает правильно, используя API-интерфейс Keras Subclassing. model.summary() также работает правильно. При попытке использовать tf.keras.utils.plot_model() для визуализации архитектуры моей модели, он просто выведет это изображение:

enter image description here

Это почти похоже на шутку от Keras Команда разработчиков. Это полная архитектура:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x


skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')

1 Ответ

1 голос
/ 25 апреля 2020

Этого нельзя было сделать, потому что в основном подклассификация моделей, как это реализовано в TensorFlow, ограничена по функциям и возможностям по сравнению с моделями, созданными с использованием Functional / Sequential API (которые называются сетями графов в терминологии TF). Если вы проверите исходный код plot_model, вы увидите следующую проверку в model_to_dot функции (которая вызывается plot_model):

if not model._is_graph_network:
  node = pydot.Node(str(id(model)), label=model.name)
  dot.add_node(node)
  return dot

Как я уже говорил, подклассовые модели не являются графовыми сетями, и поэтому для этих моделей будет построен только узел, содержащий название модели (т. е. то же, что вы наблюдали).

Это уже обсуждалось в Github Проблема и один из разработчиков TensorFlow подтвердили это поведение, предоставив следующий аргумент:

@ omalleyt12 прокомментировал:

Да, в общем, мы не можем ничего предположить о структура подклассовой модели. Если ваша Модель может представлять собой блоки слоев, и вы хотите sh для ее визуализации, мы рекомендуем вам просмотреть Функциональный API

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...