Экспорт модели Keras в protobuf с выходной формой (None, 2) - PullRequest
0 голосов
/ 24 января 2020

У меня есть модель Keras, которую я пытаюсь экспортировать в ProtoBuf

Последняя пара слоев выглядит следующим образом:

features (Dense)                (None, 128)          49280       concatenate_1[0][0]              
__________________________________________________________________________________________________
gaze_target (Dense)             (None, 2)            258         features[0][0]      

Я пытаюсь экспортировать ее так:

sess = K.get_session()

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), 'gaze_target')
graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)

Это ошибка с этим:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    191 
    192   if isinstance(dest_nodes, six.string_types):
--> 193     raise TypeError("dest_nodes must be a list.")
    194 
    195   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(

TypeError: dest_nodes must be a list.

Как экспортировать эту модель в ProtoBuf? (В конечном итоге для использования в SageMaker)

1 Ответ

0 голосов
/ 28 января 2020

Благодаря коллеге на работе мы решили это. Параметр метода graph_util.convert_variables_to_constants - это не имя слоя, а имя операции (op.name).

Правильный код:

sess = K.get_session()

outputs = [out.op.name for out in model.outputs] # Note this new line

constant_graph = graph_util.convert_variables_to_constants(sess, 
                                             sess.graph.as_graph_def(), 
                                             outputs)

graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)

...