У меня есть хорошо настроенная модель vgg16, обученная керасу.
Я хочу использовать модель в Android. Следовал этому учебнику, и первым шагом было преобразование модели keras в формат файла Tensorflow pb
.
Это имена слоев в моей модели
layer_names=[layer.name for layer in model.layers]
print(layer_names)
output
'input_1', 'block1_conv1', 'block1_conv2', 'block1_pool', 'block2_conv1', 'block2_conv2', 'block2_pool', 'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool', 'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool', 'flatten', 'dense_1', 'dropout_1', 'dense_2']
Я использую эту функцию для преобразования модели keras в тензор потока pb. A_Guide_to_Running_Tensorflow_Models_on_ Android
import os
import os.path as path
from keras import backend as K
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.examples.tutorials.mnist import input_data
def export_model(saver, model, MODEL_NAME, input_node_names, output_node_name):
tf.train.write_graph(K.get_session().graph_def, 'out', \
MODEL_NAME + '_graph.pbtxt')
saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')
freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None, \
False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
"save/restore_all", "save/Const:0", \
'out/frozen_' + MODEL_NAME + '.pb', True, "")
input_graph_def = tf.GraphDef()
with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
input_graph_def.ParseFromString(f.read())
# Resave weights in format suitable for android
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
input_graph_def, input_node_names, [output_node_name],
tf.float32.as_datatype_enum)
with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
print("graph saved!")
# first name in list
input_node_names = ['input_1']
# last name in list
output_node_name = 'dense_2'
import tensorflow as tf
model_load_path = 'vgg16-fined-tuned.h5'
model_save_path = './'
model_name = 'vgg16-fined-tuned.pb'
import keras
session = keras.backend.get_session()
init = tf.global_variables_initializer()
session.run(init)
export_model(tf.train.Saver(), model, model_name, input_node_names, output_node_name)
Выход
IndexError Traceback (most recent call last)
in
4 session.run(init)
5
----> 6 export_model(tf.train.Saver(), model, model_name, input_node_names, output_node_name)
in export_model(saver, model, MODEL_NAME, input_node_names, output_node_name)
18 False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
19 "save/restore_all", "save/Const:0", \
---> 20 'out/frozen_' + MODEL_NAME + '.pb', True, "")
21
22 input_graph_def = tf.GraphDef()
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/tools/freeze_graph.py in freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph, input_saved_model_dir, saved_model_tags, checkpoint_version)
361 input_saved_model_dir,
362 saved_model_tags.replace(" ", "").split(","),
--> 363 checkpoint_version=checkpoint_version)
364
365
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/tools/freeze_graph.py in freeze_graph_with_def_protos(***failed resolving arguments***)
188 try:
189 saver = saver_lib.Saver(
--> 190 var_list=var_list, write_version=checkpoint_version)
191 except TypeError as e:
192 # `var_list` is required to be a map of variable names to Variable
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
830 time.time() + self._keep_checkpoint_every_n_hours * 3600)
831 elif not defer_build:
--> 832 self.build()
833 if self.saver_def:
834 self._check_saver_def()
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in build(self)
842 if context.executing_eagerly():
843 raise RuntimeError("Use save/restore instead of build in eager mode.")
--> 844 self._build(self._filename, build_save=True, build_restore=True)
845
846 def _build_eager(self, checkpoint_path, build_save, build_restore):
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
879 restore_sequentially=self._restore_sequentially,
880 filename=checkpoint_path,
--> 881 build_save=build_save, build_restore=build_restore)
882 elif self.saver_def and self._name:
883 # Since self._name is used as a name_scope by builder(), we are
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _build_internal(self, names_to_saveables, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, filename, build_save, build_restore)
485
486 saveables = saveable_object_util.validate_and_slice_inputs(
--> 487 names_to_saveables)
488 if max_to_keep is None:
489 max_to_keep = 0
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in validate_and_slice_inputs(names_to_saveables)
336 # Avoid comparing ops, sort only by name.
337 key=lambda x: x[0]):
--> 338 for converted_saveable_object in saveable_objects_for_op(op, name):
339 _add_saveable(saveables, seen_ops, converted_saveable_object)
340 return saveables
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in saveable_objects_for_op(op, name)
205 else:
206 yield ResourceVariableSaveable(
--> 207 variable, "", name)
208
209
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in __init__(self, var, slice_spec, name)
81 self._var_shape = var.shape
82 if isinstance(var, ops.Tensor):
---> 83 self.handle_op = var.op.inputs[0]
84 tensor = var
85 elif isinstance(var, resource_variable_ops.ResourceVariable):
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in __getitem__(self, i)
2193
2194 def __getitem__(self, i):
-> 2195 return self._inputs[i]
2196
2197 # pylint: enable=protected-access
IndexError: list index out of range