Tensorflow: не удается правильно обрезать модель с помощью tf.graph_util.extract_sub_graph () - PullRequest
0 голосов
/ 20 апреля 2019

Я пытаюсь сократить входной конвейер Dataset и все train_op моей модели восстановлены из контрольной точки / .pb. Но с узлами не было сокращено после использования tf.graph_util.extract_sub_graph(). Мой код выглядит так:

# load checkpoint or saved_model
restorer = tf.train.import_meta_graph('./model.meta')
graph = tf.get_default_graph()
graph.as_default()

#print node names
print_nodes_name(graph)

# cut it after the first layer
nodes_to_conserve = ['model/conv1/Relu']

# extract subgraph
subgraph = tf.graph_util.extract_sub_graph(graph.as_graph_def(), nodes_to_conserve)

# for the second time
print_nodes_name(graph)

with tf.Session(graph=tf.graph_util.import_graph_def(subgraph)) as sess:
    ...

Я тестировал простую двухслойную сверточную модель, ожидая, что она разрезает модель между слоями. Когда я печатаю имя узла во второй раз, узлы второго слоя все еще там . enter image description here

Я совершенно запутался с этим extract_sub_graph (). Я ожидал получить что-то вроде этого post , но это не работает. Мои вопросы:

  1. Для node_to_conserve, я должен поставить операции или концы тензора как Relu:0?
  2. Как загрузить сохраненные веса и смещения? До или после extract_sub_graph
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...