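# Вам нужно установить фиксированный размер входного тензора при замораживании вашей модели.
# (When freezing the model, the input tensor must have a fixed shape, e.g. [1, crop_size, crop_size, 3] instead of None dimensions.)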
import tensorflow as tf
import os
from tensorflow.python.tools.freeze_graph import freeze_graph
import models
import utils
import image_utils as im
import numpy as np
# Command-line flags for the export script.
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('checkpoint_dir', './checkpoints/photo2cartoon',
                       'checkpoints directory path')
# Integer flag: give an int default rather than relying on the flag parser
# to convert the string '256'.
tf.flags.DEFINE_integer('crop_size', 256, 'crop_size, default: 256')
def _strip_text_graph(graph_def):
    """Remove Const nodes and selected attributes from *graph_def* in place.

    The stripped copy is only written out as ``text_graph.pbtxt``; it is not
    the graph that gets frozen.
    """
    stripped_attrs = ('T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
                      'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
                      'Tpaddings')
    # Walk the node list backwards so deleting a node does not shift the
    # indices that are still to be visited.
    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        if node.op == 'Const':
            del graph_def.node[i]
            # After the delete, index i points at an already-processed node
            # (or past the end if i was the last index), so skip stripping.
            continue
        for attr in stripped_attrs:
            if attr in node.attr:
                del node.attr[attr]


def _fix_batch_norm_nodes(graph_def):
    """Rewrite ref-based ops in *graph_def* in place so they can be frozen.

    RefSwitch becomes Switch, with its ``moving_*`` inputs redirected to the
    corresponding ``/read`` tensors. AssignSub/AssignAdd become Sub/Add, with
    their ``use_locking`` attribute dropped.
    """
    assign_to_plain = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in assign_to_plain:
            node.op = assign_to_plain[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def export_graph(model_name, output_node_names=('a2b_generator/Tanh',)):
    """Build the a2b generator, restore its checkpoint and write a frozen graph.

    Files written (paths relative to the current working directory):
      * ``text_graph.pbtxt``: text dump of the graph structure with Const
        nodes and some attributes stripped;
      * ``logs/``: TensorBoard event file containing the session graph;
      * ``pretrained/<model_name>``: binary GraphDef with variables folded
        into constants.

    Args:
        model_name: file name of the frozen graph written under ``pretrained``.
        output_node_names: names of the nodes kept when freezing. The default
            is the generator output node the script originally hard-coded.

    Raises:
        Exception: if ``utils.load_checkpoint`` finds no checkpoint in
            ``FLAGS.checkpoint_dir``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # The input shape is fully fixed (batch of 1) so the frozen graph has
        # a static input size.
        a_real = tf.placeholder(
            tf.float32,
            shape=[1, FLAGS.crop_size, FLAGS.crop_size, 3],
            name='input_image')
        # Called for its side effect of adding the generator ops to `graph`.
        models.generator(a_real, 'a2b', reuse=False, train=False)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        # Save a stripped, human-readable copy of the graph structure.
        text_graph_def = graph.as_graph_def()
        _strip_text_graph(text_graph_def)
        tf.train.write_graph(text_graph_def, "", "text_graph.pbtxt",
                             as_text=True)

        # Load the trained variable values.
        latest_ckpt = utils.load_checkpoint(FLAGS.checkpoint_dir, sess, saver)
        if latest_ckpt is None:
            raise Exception('No checkpoint!')
        print('Copy variables from % s' % latest_ckpt)

        # Write the graph for TensorBoard.
        writer = tf.summary.FileWriter('logs', sess.graph)
        writer.close()

        # Rewrite ref ops, fold variables into constants, and save the result.
        graph_def_to_freeze = sess.graph.as_graph_def()
        _fix_batch_norm_nodes(graph_def_to_freeze)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph_def_to_freeze, list(output_node_names))
        tf.train.write_graph(output_graph_def, 'pretrained', model_name,
                             as_text=False)
def main(unused_argv):
    """Entry point invoked by ``tf.app.run``; exports the photo2cartoon graph."""
    frozen_name = 'photo2cartoon.pb'
    print(frozen_name)
    export_graph(frozen_name)
if __name__ == '__main__':
    # Parse flags, then hand control to main() defined above.
    tf.app.run(main=main)