Как подсчитать общее количество обучаемых параметров в модели тензорного потока, определенной с графиком, загруженным из файла .pb? - PullRequest
0 голосов
/ 03 мая 2018

Я хочу посчитать параметры в модели тензорного потока. Это похоже на существующий вопрос следующим образом.

Как рассчитать общее количество обучаемых параметров в модели тензорного потока?

Но если модель определена с графиком, загруженным из файла .pb, все предложенные ответы не будут работать. В основном я загрузил график с помощью следующей функции.

def load_graph(model_file):

  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

Одним из примеров является загрузка файла frozen_graph.pb для переподготовки в tenorflow-for-poets-2.

https://github.com/googlecodelabs/tensorflow-for-poets-2

1 Ответ

0 голосов
/ 03 мая 2018

Насколько я понимаю, GraphDef не имеет достаточно информации, чтобы описать Variables. Как объяснено здесь , вам понадобится MetaGraph, который содержит как GraphDef, так и CollectionDef, что является картой, которая может описать Variables. Поэтому следующий код должен дать нам правильное число обучаемых переменных.

Экспорт MetaGraph:

import tensorflow as tf

a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, r'.\test')

Импорт MetaGraph и подсчет общего количества обучаемых параметров.

import tensorflow as tf

saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    saver.restore(sess, 'test')

total_parameters = 0
for variable in tf.trainable_variables():
    total_parameters += 1
print(total_parameters)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...