Я надеюсь, что еще не слишком поздно для этого ответа, но мне удалось сгенерировать файлы .pb, используя код вывода, предоставленный в репозитории .
Obs: я использую tenorflow 1.4.1 из-за моего графического процессора, так что это, вероятно, не будет работать на более новой версии или потребует некоторых изменений.
Демонстрация вывода загружает график и данные контрольной точки в сессию. Оттуда я мог бы использовать функцию для сохранения сеанса и графика. Вот пример моего кода:
import vggish_input
from tensorflow.python.tools import freeze_graph
def save(sess, directory, filename, saver):
"""
This function saves a checkpoint, based on the current session
"""
if not os.path.exists(directory):
os.makedirs(directory)
filepath = os.path.join(directory, filename)
saver.save(sess, filepath)
return filepath
def save_as_pb(sess, directory, filename, saver):
"""
This function saves a checkpoint, then writes the graph in a pbtxt, and then makes a frozen graph with the chekpoint and the pbtxt
"""
# Save checkpoint to freeze graph later
ckpt_filepath = save(sess, directory=directory, filename=filename, saver=saver)
pbtxt_filename = filename + '.pbtxt'
pbtxt_filepath = os.path.join(directory, pbtxt_filename)
pb_filepath = os.path.join(directory, filename + '.pb')
# This will only save the graph but the variables will not be saved.
tf.train.write_graph(graph_or_graph_def=sess.graph_def, logdir=directory, name=pbtxt_filename, as_text=True)
# Freeze graph, combining the checkpoint and
freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names=vggish_params.OUTPUT_TENSOR_NAME.split(':')[0], restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='')
return pb_filepath
Затем я вставил save_as_pb сразу после загрузки модели с контрольной точки в vggish_inference_demo.py файле:
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Graph().as_default(), tf.Session(config=config) as sess:
# Define the model in inference mode, load the checkpoint, and
# locate input and output tensors.
vggish_slim.define_vggish_slim(training=False)
vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
saver = tf.train.Saver()
save_as_pb(sess, './saved_vggish/', 'vggish', saver)