Ошибка при загрузке модели keras с deeplearning4j - PullRequest
2 голосов
/ 13 апреля 2019

Я уже некоторое время пытаюсь загрузить свою модель нейронной сети keras для моего приложения для Android с deeplearning4j.Я искал решения (столько же, сколько существует), но каждое решение вызывает новые ошибки, и я просто не мог заставить эту вещь работать.

В любом случае, я обучил NON Последовательная модель с керасом в Python и сохранена так:

model.save('model.h5')

Сейчас я пытаюсь импортировать эту модель с deeplearning4j в Android Studio.Я перепробовал много возможных вариантов, но сейчас я нахожусь здесь:

String modelPath = new ClassPathResource("res/raw/model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath)

Это, однако, вызывает следующую ошибку:

java.lang.NoClassDefFoundError: Failed resolution of: Lorg/bytedeco/javacpp/hdf5;

Как я понимаю, gradle не может разрешитьзависимость hdf5 от org.bytedeco, с которой я согласен, поскольку я исключил hdf5-platform в моей сборке Gradle, но, насколько я знаю, hdf5 не должен даже поддерживаться Android (*).

Я также попытался включить hdf5-platform и запустить тот же код, но это вызывает еще одну ошибку:

java.lang.UnsatisfiedLinkError: Platform "android-arm64" not supported by class org.bytedeco.javacpp.hdf5

Я довольно новичок в концепциях gradle и не знаю Android подробно, но кажетсячто проблема с моими зависимостями gradle.Существует также ограниченный объем информации о deeplearning4j, и я также не могу найти альтернативное решение.

Я также включу свои зависимости gradle, которые у меня есть от этого урока.

implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '1.0.0-beta3') {
    exclude group: 'org.bytedeco.javacpp-presets', module: 'opencv-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'leptonica-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'hdf5-platform'
    exclude group: 'org.nd4j', module: 'nd4j-base64'
}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3'
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86_64"

(Как) я должен изменить свои зависимости, чтобы заставить эту модель импортироваться для работы?

Или я должен каким-то образом изменить способ импорта моей модели?

1 Ответ

1 голос
/ 13 апреля 2019

deeplearning4j не может быть лучшим вариантом. Чтобы загрузить модель TensorFlow Keras в Android или даже iOS, вы можете использовать TensorFlow Lite .

Сначала вам нужно преобразовать модель Keras (.h5) в модель TFLite (.tflite)

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
tflite_model = converter.convert()
open( 'model.tflite' , 'wb' ).write( tflite_model )

Вы можете сделать следующее:

  1. Если ваша модель должна быть размещена на коме-источнике, который будет загружен вашим приложением, тогда вы можете использовать Firebase ML Kit . Для пользовательских моделей TFLite читайте здесь .

  2. Вы можете сохранить модель TFLite в папке ресурсов приложения, а затем загрузить ее MappedByteBuffer. Доступна зависимость TensorFlow Lite для Android:

    implementation ‘org.tensorflow:tensorflow-lite:1.13.1’
    

Вы можете сослаться на эту кодовую метку и на эту статью .

Вы можете загрузить MappedByteBuffer как:

private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
  AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
  FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  FileChannel fileChannel = inputStream.getChannel();
  long startOffset = fileDescriptor.getStartOffset();
  long declaredLength = fileDescriptor.getDeclaredLength();
  return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...