I want to use mobilenet_v2 in TensorFlow 1.14 for image classification, but I don't want to build the graph from scratch. So I found a mobilenet_v2 checkpoint on the official site, https://www.tensorflow.org/lite/guide/hosted_models, and downloaded the Mobilenet_V2_1.0_224 version. I then called tf.train.import_meta_graph() to load the .meta file, but it failed. The code is quite simple:
saver = tf.train.import_meta_graph("mobilenet_v2_1.4_224.ckpt.meta", clear_devices=True)
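For context, here is a slightly fuller sketch of what the script does. It is not the exact contents of mbv2_test.py: the helper names `meta_path_for` and `load_meta_graph` are made up for illustration, and it uses `tf.compat.v1.train.import_meta_graph`, the name the deprecation warning in the log recommends. TensorFlow is imported lazily inside the loader.

```python
def meta_path_for(checkpoint_prefix):
    """Return the path of the .meta file that belongs to a checkpoint prefix."""
    return checkpoint_prefix + ".meta"


def load_meta_graph(checkpoint_prefix):
    """Import the stored MetaGraphDef into a fresh graph.

    Returns (graph, saver). Hypothetical helper, assumes TensorFlow 1.14.
    """
    # Lazy import so that meta_path_for() works without TensorFlow installed.
    import tensorflow as tf

    print(tf.__version__)
    graph = tf.Graph()
    with graph.as_default():
        # clear_devices=True drops the device placements stored in the file.
        saver = tf.compat.v1.train.import_meta_graph(
            meta_path_for(checkpoint_prefix), clear_devices=True)
    return graph, saver
```

Calling `load_meta_graph("mbv2/mobilenet_v2_1.4_224.ckpt")` corresponds to the call that fails in my script.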
The error messages are:
Traceback (most recent call last):
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 427, in import_graph_def
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "mbv2_test.py", line 10, in <module>
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1449, in import_meta_graph
**kwargs)[0]
File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1473, in _import_meta_graph_with_return_elements
**kwargs))
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 857, in import_scoped_meta_graph_with_return_elements
return_elements=return_elements)
File "C:\Anaconda\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 431, in import_graph_def
raise ValueError(str(e))
ValueError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}
D:\MSK\project\mnn\0mbnet_tensorflow>python mbv2_test.py
1.14.0
2019-10-11 10:34:48.210684: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
WARNING: Logging before flag parsing goes to stderr.
W1011 10:34:48.215158 9564 deprecation_wrapper.py:119] From mbv2_test.py:10: The name tf.train.import_meta_graph is deprecated. Please use tf.compat.v1.train.import_meta_graph instead.
Traceback (most recent call last):
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 427, in import_graph_def
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "mbv2_test.py", line 10, in <module>
saver = tf.train.import_meta_graph("mbv2/mobilenet_v2_1.4_224.ckpt.meta", clear_devices=True)
File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1449, in import_meta_graph
**kwargs)[0]
File "C:\Anaconda\lib\site-packages\tensorflow\python\training\saver.py", line 1473, in _import_meta_graph_with_return_elements
**kwargs))
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 857, in import_scoped_meta_graph_with_return_elements
return_elements=return_elements)
File "C:\Anaconda\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "C:\Anaconda\lib\site-packages\tensorflow\python\framework\importer.py", line 431, in import_graph_def
raise ValueError(str(e))
ValueError: NodeDef expected inputs 'float, int32' do not match 1 inputs specified; Op<name=CrossReplicaSum; signature=input:T, group_assignment:int32 -> output:T; attr=T:type,allowed=[DT_BFLOAT16, DT_FLOAT, DT_INT32, DT_UINT32]>; NodeDef: {{node TPUReplicate/loop/CrossReplicaSum}}
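The message says the stored CrossReplicaSum node has 1 input, while the op definition in the installed TensorFlow expects two (input and group_assignment). One way to check this without importing the graph might be to parse the .meta file directly as a MetaGraphDef proto and list the matching nodes. This is only a sketch: `find_nodes_by_op` and `read_meta_graph_nodes` are hypothetical helper names, and only the top-level graph is inspected, not nodes inside the function library.

```python
def find_nodes_by_op(nodes, op_name):
    """Return (node name, number of inputs) for every node with the given op."""
    return [(n.name, len(n.input)) for n in nodes if n.op == op_name]


def read_meta_graph_nodes(meta_path):
    """Parse a .meta file as a MetaGraphDef proto and return its graph nodes."""
    # Lazy import: only this function needs TensorFlow installed.
    from tensorflow.core.protobuf import meta_graph_pb2

    meta_graph = meta_graph_pb2.MetaGraphDef()
    with open(meta_path, "rb") as f:
        meta_graph.ParseFromString(f.read())
    # Top-level nodes only; graph_def.library functions are not covered here.
    return meta_graph.graph_def.node
```

For example, `find_nodes_by_op(read_meta_graph_nodes("mbv2/mobilenet_v2_1.4_224.ckpt.meta"), "CrossReplicaSum")` should list the TPU-related nodes and how many inputs each one was saved with.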
Thanks in advance for any help with this problem.