Загрузка официального метафайла мобильной сети с использованием tf.train.import_meta_graph завершается неудачно - PullRequest
0 голосов
/ 11 октября 2019

Я хочу использовать mobilenet_v2 в Tensorflow 1.14 для классификации изображений, но я не хочу строить график с нуля. Поэтому я нахожу контрольный пункт mobilenet_v2 с официального сайта https://www.tensorflow.org/lite/guide/hosted_models. Я скачал версию Mobilenet_V2_1.0_224. Затем я использовал функцию tf.train.import_meta_graph () для загрузки метафайла, но это не удалось. Код довольно прост.

saver = tf.train.import_meta_graph("mobilenet_v2_1.4_224.ckpt.meta", clear_devices=True)

и сообщения об ошибках:

Traceback (most recent call last):
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 427, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "mbv2_test.py", line 10, in <module>
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1449, in import_meta_graph
    **kwargs)[0]
  File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1473, in _import_meta_graph_with_return_elements
    **kwargs))
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 857, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}

D:\MSK\project\mnn\0mbnet_tensorflow>python mbv2_test.py
1.14.0
2019-10-11 10:34:48.210684: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
WARNING: Logging before flag parsing goes to stderr.
W1011 10:34:48.215158  9564 deprecation_wrapper.py:119] From mbv2_test.py:10: The name tf.train.import_meta_graph is deprecated. Please use tf.compat.v1.train.import_meta_graph instead.

Traceback (most recent call last):
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 427, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "mbv2_test.py", line 10, in <module>
    saver = tf.train.import_meta_graph("mbv2/mobilenet_v2_1.4_224.ckpt.meta", clear_devices=True)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1449, in import_meta_graph
    **kwargs)[0]
  File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1473, in _import_meta_graph_with_return_elements
    **kwargs))
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 857, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}

Спасибо, если вы можете помочь мне с этой проблемой.

...