Ошибка «IndexError: список индексов вне диапазона» при вызове tenorflow freeze_graph - PullRequest
0 голосов
/ 21 февраля 2020

У меня есть хорошо настроенная модель vgg16, обученная керасу.

Я хочу использовать модель в Android. Следовал этому учебнику, и первым шагом было преобразование модели keras в формат файла Tensorflow pb.

Это имена слоев в моей модели

layer_names=[layer.name for layer in model.layers]
print(layer_names)

output

'input_1', 'block1_conv1', 'block1_conv2', 'block1_pool', 'block2_conv1', 'block2_conv2', 'block2_pool', 'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool', 'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool', 'flatten', 'dense_1', 'dropout_1', 'dense_2']

Я использую эту функцию для преобразования модели keras в тензор потока pb. A_Guide_to_Running_Tensorflow_Models_on_ Android

import os
import os.path as path
from keras import backend as K

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

from tensorflow.examples.tutorials.mnist import input_data

def export_model(saver, model, MODEL_NAME, input_node_names, output_node_name):
    tf.train.write_graph(K.get_session().graph_def, 'out', \
        MODEL_NAME + '_graph.pbtxt')

    saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')

    freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None, \
        False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
        "save/restore_all", "save/Const:0", \
        'out/frozen_' + MODEL_NAME + '.pb', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    # Resave weights in format suitable for android
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")
# first name in list
input_node_names = ['input_1']
# last name in list
output_node_name = 'dense_2'
import tensorflow as tf

model_load_path = 'vgg16-fined-tuned.h5'
model_save_path = './'
model_name = 'vgg16-fined-tuned.pb'

import keras
session = keras.backend.get_session()
init = tf.global_variables_initializer()
session.run(init)

export_model(tf.train.Saver(), model, model_name, input_node_names, output_node_name)

Выход

IndexError                                Traceback (most recent call last)
 in 
      4 session.run(init)
      5 
----> 6 export_model(tf.train.Saver(), model, model_name, input_node_names, output_node_name)

 in export_model(saver, model, MODEL_NAME, input_node_names, output_node_name)
     18         False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
     19         "save/restore_all", "save/Const:0", \
---> 20         'out/frozen_' + MODEL_NAME + '.pb', True, "")
     21 
     22     input_graph_def = tf.GraphDef()

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/tools/freeze_graph.py in freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph, input_saved_model_dir, saved_model_tags, checkpoint_version)
    361       input_saved_model_dir,
    362       saved_model_tags.replace(" ", "").split(","),
--> 363       checkpoint_version=checkpoint_version)
    364 
    365 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/tools/freeze_graph.py in freeze_graph_with_def_protos(***failed resolving arguments***)
    188       try:
    189         saver = saver_lib.Saver(
--> 190             var_list=var_list, write_version=checkpoint_version)
    191       except TypeError as e:
    192         # `var_list` is required to be a map of variable names to Variable

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
    830           time.time() + self._keep_checkpoint_every_n_hours * 3600)
    831     elif not defer_build:
--> 832       self.build()
    833     if self.saver_def:
    834       self._check_saver_def()

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in build(self)
    842     if context.executing_eagerly():
    843       raise RuntimeError("Use save/restore instead of build in eager mode.")
--> 844     self._build(self._filename, build_save=True, build_restore=True)
    845 
    846   def _build_eager(self, checkpoint_path, build_save, build_restore):

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _build(self, checkpoint_path, build_save, build_restore)
    879           restore_sequentially=self._restore_sequentially,
    880           filename=checkpoint_path,
--> 881           build_save=build_save, build_restore=build_restore)
    882     elif self.saver_def and self._name:
    883       # Since self._name is used as a name_scope by builder(), we are

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _build_internal(self, names_to_saveables, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, filename, build_save, build_restore)
    485 
    486     saveables = saveable_object_util.validate_and_slice_inputs(
--> 487         names_to_saveables)
    488     if max_to_keep is None:
    489       max_to_keep = 0

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in validate_and_slice_inputs(names_to_saveables)
    336                          # Avoid comparing ops, sort only by name.
    337                          key=lambda x: x[0]):
--> 338     for converted_saveable_object in saveable_objects_for_op(op, name):
    339       _add_saveable(saveables, seen_ops, converted_saveable_object)
    340   return saveables

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in saveable_objects_for_op(op, name)
    205       else:
    206         yield ResourceVariableSaveable(
--> 207             variable, "", name)
    208 
    209 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saving/saveable_object_util.py in __init__(self, var, slice_spec, name)
     81     self._var_shape = var.shape
     82     if isinstance(var, ops.Tensor):
---> 83       self.handle_op = var.op.inputs[0]
     84       tensor = var
     85     elif isinstance(var, resource_variable_ops.ResourceVariable):

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in __getitem__(self, i)
   2193 
   2194     def __getitem__(self, i):
-> 2195       return self._inputs[i]
   2196 
   2197 # pylint: enable=protected-access

IndexError: list index out of range
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...