Сбой Android-приложения с TensorFlow Lite: NullPointerException при вызове «void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)» - PullRequest
0 голосов
07 марта 2020

Я пытаюсь передать массив значений в свою модель TensorFlow Lite. Приложение аварийно завершает работу после того, как я ввожу значения: они сохраняются в массив и передаются в модель. Вот фрагмент кода.

inferButton.setOnClickListener(new View.OnClickListener() {
    /**
     * Reads the 14 feature fields, runs the model and shows the prediction.
     *
     * <p>Each field is read exactly once, in the order of the training columns. The previous
     * version listed getText5 and getText10 twice, producing 16 values and shifting every
     * later feature by one or two positions.
     */
    @Override
    public void onClick(View view) {
        EditText[] fields = {
            inputNumber1, inputNumber2, inputNumber3, inputNumber4, inputNumber5,
            inputNumber6, inputNumber7, inputNumber8, inputNumber9, inputNumber10,
            inputNumber11, inputNumber12, inputNumber13, inputNumber14
        };

        int[] attributes = new int[fields.length];
        for (int i = 0; i < fields.length; i++) {
            try {
                attributes[i] = Integer.parseInt(fields[i].getText().toString().trim());
            } catch (NumberFormatException e) {
                // Empty or non-integer input: flag the field instead of crashing the app.
                fields[i].setError("Enter a whole number");
                fields[i].requestFocus();
                return;
            }
        }

        int prediction = doInference(attributes);
        // setText(int) would treat the value as a string resource ID; convert it to text explicitly.
        outputNumber.setText(String.valueOf(prediction));
    }
});

/**
 * Runs the TFLite model on one sample and returns its output rounded to the nearest integer.
 *
 * <p>The training script builds {@code Dense(units=1, input_shape=[14])}, so the model takes a
 * batch of one 14-value row and produces a single value. The values are therefore copied into a
 * {@code float[1][14]} input and read from a {@code float[1][1]} output. Passing {@code int}
 * arrays, as before, does not match a float model. NOTE(review): this assumes the converted
 * model kept Keras' default float32 tensors — confirm via
 * {@code tflite.getInputTensor(0).dataType()}.
 *
 * <p>NOTE(review): the training script MinMax-scales every feature, but these values are passed
 * unscaled — confirm whether scaling must be applied here as well.
 *
 * @param attributeArray exactly 14 feature values, in training-column order
 * @return the model output rounded to the nearest integer
 * @throws IllegalStateException if {@code tflite} is null (model failed to load)
 * @throws IllegalArgumentException if {@code attributeArray} is null or not exactly 14 long
 */
public int doInference(int[] attributeArray) {
    final int featureCount = 14;

    // tflite stays null when loading the model failed; fail with a clear message instead of an NPE.
    if (tflite == null) {
        throw new IllegalStateException(
                "TFLite interpreter is not initialized; the model failed to load");
    }
    if (attributeArray == null || attributeArray.length != featureCount) {
        throw new IllegalArgumentException("Expected " + featureCount + " attributes but got "
                + (attributeArray == null ? "null" : String.valueOf(attributeArray.length)));
    }

    // Input shape is [1][14]: a batch containing one sample.
    float[][] inputVal = new float[1][featureCount];
    for (int i = 0; i < featureCount; i++) {
        inputVal[0][i] = attributeArray[i];
    }

    // Output shape is [1][1].
    float[][] outputVal = new float[1][1];

    tflite.run(inputVal, outputVal);

    // Round rather than truncate, so e.g. 0.7 maps to 1 instead of 0.
    return Math.round(outputVal[0][0]);
}

Текст ошибки:

E/AndroidRuntime: FATAL EXCEPTION: main
    Process: com.example.appdoctor, PID: 13022
    java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
        at com.example.appdoctor.Prediction.doInference(Prediction.java:90)
        at com.example.appdoctor.Prediction$1.onClick(Prediction.java:75)
        at android.view.View.performClick(View.java:7125)
        at android.view.View.performClickInternal(View.java:7102)
        at android.view.View.access$3500(View.java:801)
        at android.view.View$PerformClick.run(View.java:27336)
        at android.os.Handler.handleCallback(Handler.java:883)
        at android.os.Handler.dispatchMessage(Handler.java:100)
        at android.os.Looper.loop(Looper.java:214)
        at android.app.ActivityThread.main(ActivityThread.java:7356)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)

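Я думаю, что массив, возможно, не заполняется значениями, или я ошибаюсь в методе doInference(), потому что не до конца понимаю структуру своей нейронной сети. (Впрочем, судя по трассировке стека, null — это сам объект tflite, у которого вызывается run(), а не массив.)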

Вот код на Python:

import glob
import os
from keras.models import Sequential, load_model
import numpy as np
import pandas as pd
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import matplotlib.pyplot as plt
import keras as k
import tensorflow as tf
from tensorflow import keras
from tensorflow import lite

df = pd.read_csv("kidney4.csv")
df = df.dropna(axis=0)

# Label-encode only the non-numeric (categorical / string) columns.
# `dtype == np.number` is not an "is numeric" test: np.number is an abstract scalar
# type, not a dtype, so e.g. an int64 column does not compare equal to it and was
# label-encoded (its values remapped to 0..k-1) by mistake.
for column in df.columns:
    if pd.api.types.is_numeric_dtype(df[column]):
        continue
    df[column] = LabelEncoder().fit_transform(df[column])

X = df.drop(["classification"], axis=1)
y = df["classification"]

# Scale every feature to [0, 1]. The model is trained on these scaled values, so any
# input given to predict() (or to an exported model) must be scaled the same way.
x_scaler = MinMaxScaler()
x_scaler.fit(X)
column_names = X.columns
X[column_names] = x_scaler.transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

# A single Dense unit over the input features, trained with MSE.
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[14])])
model.compile(optimizer='sgd', loss='mean_squared_error')

model.fit(X_train, y_train, epochs=500)

input_array = np.array([40, 8, 1, 2, 0, 2, 6, 10, 34, 40, 16, 23, 67, 25])
# Apply the training-time MinMax scaling before predicting. Wrapping the row in a
# DataFrame with the training column names keeps sklearn's feature-name check happy.
# The result has shape (1, n_features), like np.expand_dims(input_array, axis=0).
input_frame = pd.DataFrame([input_array], columns=column_names)
input_array_for_prediction = x_scaler.transform(input_frame).astype(np.float32)

# NOTE(review): this script never writes kidney_test_model.tflite, the file the Android
# app loads. That app also feeds raw, unscaled values, so it would need x_scaler's
# data_min_ / data_range_ to reproduce these predictions.
print(model.predict(input_array_for_prediction))

Ниже — скриншот вывода model.summary():

[изображение не сохранилось]

Вот весь мой Java-код:

import android.content.*;
import android.content.res.*;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.*;

import androidx.appcompat.app.AppCompatActivity;

import org.tensorflow.lite.*;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class Prediction extends AppCompatActivity {

    EditText inputNumber1,inputNumber2,inputNumber3,inputNumber4,inputNumber5,inputNumber6,inputNumber7,inputNumber8,inputNumber9,inputNumber10,inputNumber11,inputNumber12,inputNumber13,inputNumber14;
    Button inferButton;
    TextView outputNumber;
    Interpreter tflite;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_prediction);

        inputNumber1=(EditText)findViewById(R.id.editText);
        inputNumber2=(EditText)findViewById(R.id.editText2);
        inputNumber3=(EditText)findViewById(R.id.editText3);
        inputNumber4=(EditText)findViewById(R.id.editText4);
        inputNumber5=(EditText)findViewById(R.id.editText5);
        inputNumber6=(EditText)findViewById(R.id.editText6);
        inputNumber7=(EditText)findViewById(R.id.editText7);
        inputNumber8=(EditText)findViewById(R.id.editText8);
        inputNumber9=(EditText)findViewById(R.id.editText9);
        inputNumber10=(EditText)findViewById(R.id.editText10);
        inputNumber11=(EditText)findViewById(R.id.editText11);
        inputNumber12=(EditText)findViewById(R.id.editText12);
        inputNumber13=(EditText)findViewById(R.id.editText13);
        inputNumber14=(EditText)findViewById(R.id.editText14);
        outputNumber=(TextView)findViewById(R.id.outputNumber);
        inferButton=(Button)findViewById(R.id.predictButton);

        try{
            tflite=new Interpreter(loadModelFile());
        }catch(Exception ex){
            ex.printStackTrace();
        }
        inferButton.setOnClickListener(new View.OnClickListener(){
            @Override
            public void onClick(View view){

                int getText1=Integer.parseInt(inputNumber1.getText().toString());
                int getText2=Integer.parseInt(inputNumber2.getText().toString());
                int getText3=Integer.parseInt(inputNumber3.getText().toString());
                int getText4=Integer.parseInt(inputNumber4.getText().toString());
                int getText5=Integer.parseInt(inputNumber5.getText().toString());
                int getText6=Integer.parseInt(inputNumber6.getText().toString());
                int getText7=Integer.parseInt(inputNumber7.getText().toString());
                int getText8=Integer.parseInt(inputNumber8.getText().toString());
                int getText9=Integer.parseInt(inputNumber9.getText().toString());
                int getText10=Integer.parseInt(inputNumber10.getText().toString());
                int getText11=Integer.parseInt(inputNumber11.getText().toString());
                int getText12=Integer.parseInt(inputNumber12.getText().toString());
                int getText13=Integer.parseInt(inputNumber13.getText().toString());
                int getText14=Integer.parseInt(inputNumber14.getText().toString());

                int attributes[]={getText1,getText2,getText3,getText4,getText5,getText5,getText6,getText7,getText8,getText9,getText10,getText10,getText11,getText12,getText13,getText14};

                int prediction=doInference(attributes);
                //float prediction=doInference(inputNumber.getText().toString());
                outputNumber.setText(prediction);
            }
        });
    }
    public int doInference(int [] attributeArray){
        int[] inputVal=new int[14];
        for(int i=0;i<14;i++) {
            inputVal[i] = attributeArray[i];
        }
        int[][] outputval=new int[1][1];

        //Run inference passing the input shape and getting the output shape
        tflite.run(inputVal, outputval);

        //Inferred value is at [0][0]
        int inferredValue=outputval[0][0];

        return inferredValue;
    }
    private MappedByteBuffer loadModelFile() throws IOException {
        AssetFileDescriptor fileDescriptor=this.getAssets().openFd("kidney_test_model.tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

}

Любая помощь приветствуется. Спасибо.

...