Я играю с TensorFlow, и обработка бэкэнда должна происходить в Java. Я взял одну из моделей из https://developers.google.com/machine-learning/crash-course и сохранил ее с помощью tf.saved_model.save (my_model, "house_price_median_income") (используя контейнер docker). Я скопировал модель и загрузил ее в Java (используя материал 2.0, созданный из исходного кода, потому что я на windows). Я могу загрузить модель и запустить ее:
try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
try (Session session = model.session()) {
Session.Runner runner = session.runner();
float[][] in = new float[][]{ {2.1518f} } ;
Tensor<?> jack = Tensor.create(in);
runner.feed("serving_default_layer1_input", jack);
float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);
for (int i = 0; i < probabilities.length; ++i) {
System.out.println(String.format("-- Input #%d", i));
for (int j = 0; j < probabilities[i].length; ++j) {
System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
}
}
}
}
Выше указано жестко для ввода и вывода, но я хочу иметь возможность прочитать модель и предоставить некоторую информацию, чтобы конечный пользователь мог выбрать вход и вывод, et c.
Я могу получить входы и выходы с помощью команды python: save_model_cli show --dir ./house_price_median_income --all
Что я хочу сделать, это получить входы и выходы через Java, поэтому моему коду не нужно выполнять скрипт python, чтобы получить их. Я могу получить операции через:
Graph graph = model.graph();
Iterator<Operation> itr = graph.operations();
while (itr.hasNext()) {
GraphOperation e = (GraphOperation)itr.next();
System.out.println(e);
И это выводит и входы, и выходы как «операции», НО, как мне узнать, что это вход и \ или выход? Инструмент python использует SignatureDef, но, похоже, его вообще нет в TensorFlow 2.0 java. Я упускаю что-то очевидное или просто отсутствует в библиотеке TensforFlow 2.0 Java?