Вы можете использовать TensorFlowInferenceInterface
, чтобы делать прогнозы, используя файл .pb.Сначала поместите файл .pb в папку ресурсов вашего приложения.
- В файле build.gradle (Module: app) добавьте следующую зависимость:
implementation 'org.tensorflow:tensorflow-android:1.11.0'
- Затем инициализируйтеTensorFlowInferenceInterface, если имя файла вашей модели - "model.pb", тогда
TensorFlowInferenceInterface tensorFlowInferenceInterface = new TensorFlowInferenceInterface(context.getAssets() , "file:///android_asset/model.pb") ;
tensorFlowInferenceInterface.feed( INPUT_NAME , inputs , 1, 28, 28);
, где INPUT_NAME
- имя вашего входного слоя.1 , 50
- входные размеры.
tensorFlowInferenceInterface.run( new String[]{ OUTPUT_NAME } );
, где OUTPUT_NAME
- имя выходного слоя.
float[] outputs = new float[ nuymber_of_classes ];
tensorFlowInferenceInterface.fetch( OUTPUT_NAME , outputs ) ;
outputs
- это значения с плавающей точкой, предсказанные вашей моделью.
Вот полный код:
TensorFlowInferenceInterface tensorFlowInferenceInterface = new
TensorFlowInferenceInterface(context.getAssets() , "file:///android_asset/model.pb");
tensorFlowInferenceInterface.feed( INPUT_NAME , inputs , 1, 28, 28);
tensorFlowInferenceInterface.run( new String[]{ OUTPUT_NAME } );
float[] outputs = new float[ nuymber_of_classes ];
tensorFlowInferenceInterface.fetch( OUTPUT_NAME , outputs ) ;