Как оказалось, проблема была вызвана неправильной заморозкой графа. Рабочую версию я выложил ниже:
"""Experimenting with Keras VGG16."""
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img
from IPython import embed
# Side length, in pixels, of the square images fed to VGG16
# (used for both the image resize and the model's input_shape).
input_size = 64

# NOTE(review): not referenced by any function in this file; kept for
# compatibility with anything that imports it.
export_dir = "/tmp/export"
def LoadInput(path='mug.jpg', size=None):
    """Load an image file and turn it into a one-image VGG16 input batch.

    Args:
      path: Image file to load. Defaults to 'mug.jpg', resolved relative to
        the current working directory.
      size: Side length, in pixels, the image is resized to (it is resized
        to a size x size square). Defaults to the module-level
        ``input_size`` at call time.

    Returns:
      An array of shape (1, size, size, channels) that has been passed
      through ``vgg16.preprocess_input``.
    """
    if size is None:
        size = input_size
    image = img_to_array(load_img(path, target_size=(size, size)))
    # Prepend a batch axis: (H, W, C) -> (1, H, W, C).
    input_batch = image.reshape((1,) + image.shape)
    return preprocess_input(input_batch)
# Adapted from
# stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb
def FreezeSession(session,
                  keep_var_names=None,
                  output_names=None,
                  clear_devices=True):
    """Freeze the graph of ``session`` into a GraphDef with constant weights.

    Every global variable whose op name is not in ``keep_var_names`` is
    converted to a constant holding its current value in ``session``
    (via ``tf.graph_util.convert_variables_to_constants``).

    Args:
      session: TF1 session whose graph and variable values are frozen.
      keep_var_names: Op names of variables to exclude from freezing.
      output_names: Op names of the nodes the frozen graph must retain.
        The caller's list is not modified.
      clear_devices: If true, blank the ``device`` field of every node so
        the frozen graph does not carry the exporting session's placement.

    Returns:
      The frozen ``GraphDef``.
    """
    graph = session.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(
            set(global_var_names).difference(keep_var_names or []))
        # Build a new list: `+=` on the argument would extend the caller's
        # list in place. The variable names are appended as extra outputs,
        # presumably so their (now constant) nodes survive sub-graph
        # extraction even when unreachable from the requested outputs.
        output_names = list(output_names or []) + global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ''
        return tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
# Based on
# medium.com/@pipidog/how-to-convert-your-keras-models-to-tensorflow-e471400b886a
def RunModel():
    """Run VGG16 in Keras, freeze it to a .pb file, then rerun the frozen graph.

    Builds a VGG16 without its classifier head, flattens its output, and
    prints the Keras prediction for the image from ``LoadInput``. Writes the
    frozen graph to /tmp/keras-vgg.pb, reloads that file into a fresh default
    graph, and prints the frozen graph's prediction for the same input batch
    so the two outputs can be compared.
    """
    vgg16 = VGG16(input_shape=(input_size, input_size, 3), include_top=False)
    output = Flatten()(vgg16.get_output_at(-1))
    model = Model(vgg16.input, output)
    model.summary()  # summary() prints itself and returns None.

    # Load once so both runs are guaranteed to see the same batch.
    input_batch = LoadInput()
    print(model.predict(input_batch))

    # Capture tensor names before the graph is reset. Freezing keeps node
    # names and import_graph_def(name='') adds no prefix, so these resolve in
    # the reloaded graph, unlike hard-coded names that depend on Keras's
    # automatic layer numbering.
    input_name = model.inputs[0].name
    output_name = model.outputs[0].name

    frozen_graph = FreezeSession(
        K.get_session(), output_names=[out.op.name for out in model.outputs])
    # write_graph returns the path of the file it wrote.
    pb_path = tf.train.write_graph(
        frozen_graph, '/tmp', 'keras-vgg.pb', as_text=False)

    tf.reset_default_graph()
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as pb_file:
        graph_def.ParseFromString(pb_file.read())
    with tf.Session() as session:
        # Actually enter the context so the import targets session.graph.
        with session.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        tensor_input = session.graph.get_tensor_by_name(input_name)
        tensor_output = session.graph.get_tensor_by_name(output_name)
        print(session.run(tensor_output, {tensor_input: input_batch}))
def main():
    """Script entry point; delegates all work to RunModel()."""
    RunModel()


if __name__ == "__main__":
    main()