Насколько я понимаю, GraphDef
не имеет достаточно информации, чтобы описать Variables
. Как объяснено здесь , вам понадобится MetaGraph
, который содержит как GraphDef
, так и CollectionDef
, что является картой, которая может описать Variables
. Поэтому следующий код должен дать нам правильное число обучаемых переменных.
Экспорт MetaGraph:
import tensorflow as tf
a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])
with tf.Session() as sess:
sess.run(init)
saver.save(sess, r'.\test')
Импорт MetaGraph и подсчет общего количества обучаемых параметров.
import tensorflow as tf
saver = tf.train.import_meta_graph('test.meta')
with tf.Session() as sess:
saver.restore(sess, 'test')
total_parameters = 0
for variable in tf.trainable_variables():
total_parameters += 1
print(total_parameters)