Запуск модели тензорного потока в Android - PullRequest
0 голосов
/ 28 апреля 2018

Я пытаюсь запустить простой классификатор радужной оболочки в приложении для Android. Я создал MLP в keras, преобразовал его в формат .pb и поместил в папку активов. Модель керас:

data = sklearn.datasets.load_iris()
x=data.data
y=data.target
x=np.array(x)
y=np.array(y)



x_train, x_test, y_train, y_test= train_test_split(x, y, test_size = 0.25)


inputs=Input(shape=(4,))
x=Dense(10,activation="relu",name="input_layer")(inputs)
x=Dense(10,activation="relu")(x)
x=Dense(15,activation="relu")(x)
x=Dense(3,activation="softmax",name="output_layer")(x)

model=Model(inputs,x)

sgd = SGD(lr=0.05, momentum=0.9, decay=0.0001, nesterov=False)

model.compile(optimizer=sgd, loss="sparse_categorical_crossentropy",  metrics=["accuracy"])
model.fit(x_train,y_train,batch_size=20, epochs=100, verbose=0)

Код в AndroidStudio (у меня есть 4 поля для ввода чисел, 1 поле для вывода и 1 кнопка. Метод предиката вызывается при нажатии кнопки):

static{
    System.loadLibrary("tensorflow_inference");
}

String model_name ="file:///android_asset/iris_model.pb";
String output_name = "output_layer";
String input_name = "input_data";
TensorFlowInferenceInterface tfinterface;



@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);
    tfinterface = new TensorFlowInferenceInterface(getAssets(),model_name);

}




public void predictClick(View v){
    TextView antwort = (TextView)findViewById(R.id.antwort);
    EditText number1 = (EditText)findViewById(R.id.number1);
    EditText number2 = (EditText)findViewById(R.id.number2);
    EditText number3 = (EditText)findViewById(R.id.number3);
    EditText number4 = (EditText)findViewById(R.id.number4);
    Button predict = (Button)findViewById(R.id.button);

    float[] result = {};




    String zahl1 = number1.getText().toString();
    String zahl2 = number2.getText().toString();
    String zahl3 = number3.getText().toString();
    String zahl4 = number4.getText().toString();
    float n1 = Float.parseFloat(zahl1);
    float n2 = Float.parseFloat(zahl2);
    float n3 = Float.parseFloat(zahl3);
    float n4 = Float.parseFloat(zahl4);


    float[] inputs={n1,n2,n3,n4};


    //im pretty sure these lines cause the error
    tfinterface.feed(input_name,inputs,4,1);
    tfinterface.run(new String[]{output_name});
    tfinterface.fetch(output_name,result);



    antwort.setText(Float.toString(result[0]));




}

Сборка выполняется без ошибок, но когда я нажимаю кнопку прогнозирования, приложение зависает. Когда я покидаю линии

tfinterface.feed(input_name,inputs,4,1);
tfinterface.run(new String[]{output_name});
tfinterface.fetch(output_name,result);

приложение работает правильно, поэтому я думаю, что это ошибка.

1 Ответ

0 голосов
/ 10 мая 2018

Входной слой вашей модели называется «input_layer» в коде Python, но «input_data» в коде Java.

Также проверьте вывод logcat. Вы должны получить сообщение об ошибке, в котором говорится, что модель не имеет входного слоя, который вы ищете.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...