Использование модели TensorFlow 2.1.0, построенной на Python в Java TensorFlow 1.15 | В графике нет операции с именем [input] - PullRequest
1 голос
/ 07 мая 2020

У меня есть модель, написанная на Python 3.7 с использованием TensorFlow 2.1.0. и я пытаюсь использовать его в приложении Java (используя TensorFlow 1.4), однако модель не принимает ввод. Я бы предположил, что это проблема совместимости, но модель успешно загружается в Java. Я пробовал использовать keras.Sequential и keras.Model, но, похоже, это не имеет значения. Я видел, как tf.placeholder используется в TF v1, но понимаю, что замена v2 - это tf.keras.Input.

Python:

#method1
model = tf.keras.Sequential([
    tf.keras.Input(name='input', shape=(60,), dtype=tf.dtypes.float32),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(30, activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax', name='output')
])
#method 2
inputs = tf.keras.Input(name='input', shape=(60,), dtype=tf.dtypes.float32)
outputs = tf.keras.layers.Dense(3, activation='softmax')(inputs)
model = tf.keras.Model(inputs, outputs)

Java:

Session.Runner runner = session.runner();
runner.feed("input", Tensor.create(testData));        

List<Tensor<?>> tensors = runner.fetch("output").run();
System.out.println("Answer is: " + tensors.get(0).floatValue());

Исключение:

2020-05-07 01:32:23.596732: I tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { serve }; Status: success. Took 50986 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
    at org.tensorflow.Session$Runner.operationByName(Session.java:380)
    at org.tensorflow.Session$Runner.parseOutput(Session.java:389)
    at org.tensorflow.Session$Runner.feed(Session.java:131)
    at com.treyyoder.smurge.ml.TensorFlowTest.main(TensorFlowTest.java:40)

!!!!!!!!!!!!!!!!!!!!!!! ОБНОВЛЕНИЕ !!!!!!!!!!!!!!!!!!!!!!

Per @ karl-lessard предложение, я включил org.tensorflow:proto для проверки MetaGraphDef

MetaGraphDef составляет ~ 15 тыс. Строк, это был полезный бит:

node {
    name: "StatefulPartitionedCall"
    op: "StatefulPartitionedCall"
    input: "serving_default_input"
    input: "dense/kernel"
    input: "dense/bias"
    input: "dense_1/kernel"
    input: "dense_1/bias"
    input: "output/kernel"
    input: "output/bias"
    attr {
      key: "_gradient_op_type"
      value {
        s: "PartitionedCallUnused"
      }
    }
    attr {
      key: "f"
      value {
        func {
          name: "__inference_signature_wrapper_9526"
        }
      }
    }
    attr {
      key: "Tout"
      value {
        list {
          type: DT_FLOAT
        }
      }
    }
    attr {
      key: "config_proto"
      value {
        s: "\n\a\n\003CPU\020\001\n\a\n\003GPU\020\0012\005*\0010J\0008\001"
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: -1
            }
            dim {
              size: 3
            }
          }
        }
      }
    }
    attr {
      key: "Tin"
      value {
        list {
          type: DT_FLOAT
          type: DT_RESOURCE
          type: DT_RESOURCE
          type: DT_RESOURCE
          type: DT_RESOURCE
          type: DT_RESOURCE
          type: DT_RESOURCE
        }
      }
    }
  }

...

node {
    name: "serving_default_input"
    op: "Placeholder"
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: -1
          }
          dim {
            size: 60
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: -1
            }
            dim {
              size: 60
            }
          }
        }
      }
    }
  }

...

signature_def {
  key: "serving_default"
  value {
    inputs {
      key: "input"
      value {
        name: "serving_default_input:0"
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: -1
          }
          dim {
            size: 60
          }
        }
      }
    }
    outputs {
      key: "output"
      value {
        name: "StatefulPartitionedCall:0"
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: -1
          }
          dim {
            size: 3
          }
        }
      }
    }
    method_name: "tensorflow/serving/predict"
  }
}

Я обнаружил правильный ввод serving_default_input и вывод StatefulPartitionedCall

Обновлен Java код:

float[] fa = //Data you are passing to your model

List<Tensor<?>> tensor = runner.feed("serving_default_input", Tensor.create(fa))
    .fetch("StatefulPartitionedCall").run();

Tensor<Float> t1 = tensor.get(0).expect(Float.class);
float[][] vector = t1.copyTo(new float[1][3]);
for (float[] f : vector) {
  for (float ff : f) {
    System.out.println("res: " + ff);
  }
}

1 Ответ

1 голос
/ 07 мая 2020

Лучший вариант - динамически извлекать эти имена из сигнатур модели и передавать их в вашу модель для вывода.

Чтобы увидеть в Java, каков список входов / выходов вашей сохраненной модели, вы можете получить MetaGraphDef из SavedModelBundle, как описано здесь: Tensorflow 2.0 & Java API . (вы также можете дважды проверить, используя утилиту командной строки [saved_model_cli][1]).

Но имейте в виду, что в TF2.x есть ошибка, когда дело доходит до функциональных моделей, когда TF переходит к некоторому недокументированному изменению имени при кодировании сигнатур входов / выходов, как описано здесь .

Кроме того, вы можете взглянуть на следующую версию TF Java, которая изначально поддерживает версии TF2.x, но в настоящий момент доступна только в виде снимков. .

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...