Я сталкиваюсь с некоторыми проблемами, связанными с использованием API Java Tensorflow.
По сути, я пытаюсь предсказать некоторые изображения, используя замороженную модель, которую я обучил на Python, но я хочу сделать эти выводы с Tensorflow в Java для некоторых приложений, которые я разработаю позже, если это сработает.
Я начал с экспорта моей модели Python в виде файла .pb, который затем можно загрузить в Tensorflow и использовать в целях вывода, который я протестировал в Python и работает без проблем.
Затем я попытался изменить пример LabelImage.java, представленный в примерах Java Tensorflow, которые можно найти на GitHub (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java).. Я в основном изменил пути к модели и образу, который буду использовать. И после успешногоисправляя некоторые ошибки, код можно было запустить, но я получил эту ошибку:
Exception in thread "main" java.lang.UnsupportedOperationException: Generic conv implementation does not support grouped convolutions for now.
[[{{node conv2d_1/convolution}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_1_0_0, conv2d_1/kernel)]]
Я новичок в Java и Tensorflow в целом, и я попытался найти похожие ошибки, например, ту, которую я получил, и я сделалне могу найти ничего полезного. Интересно, если ошибка пытается сказать мне, что текущий API Tensorflow для Java делаетне поддерживает свертки, что я был бы весьма удивлен, если бы это было так.В любом случае, я немного растерялся из-за того, что я мог бы сделать, чтобы решить эту проблему, и я надеюсь, что кто-то может помочь мне найти исправление.
Некоторые детали : я создал и обучил U-Net модели на Keras и использовал метод от какого-то пользователя на GitHub, который преобразует обученную модель Keras в файл .pb, который может быть перезагружен в Tensorflow и запущен для вывода (пользователь: https://github.com/amir-abdi/keras_to_tensorflow). Эта часть перезагрузки и вывода работаетотлично в Python (я проверял это, чтобы быть уверенным).
Кажется, что ошибка происходит в этом фрагменте кода:
private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
Tensor<Float> result =
s.runner().feed("input_1", image).fetch("conv2d_24/Sigmoid").run().get(0).expect(Float.class)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
return result.copyTo(new float[1][nlabels])[0];
}
}
Этот код не был изменен, как я сказал, что я только что изменилпути ввода, которые указывают на мою модель, и образец изображения для тестирования. Точные детали, которые я изменил, можно найти ниже:
public static void main(String[] args) throws Exception {
System.out.println("TensorFlow version: " + TensorFlow.version());
byte[] graphDef = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\test.pb"));
byte[] imageBytes = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\02.png"));
try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
}
Я надеюсь, что этой информации может быть достаточно для понимания проблемы.