Tensorflow 2.0 & Java API - PullRequest
       13

Tensorflow 2.0 & Java API

0 голосов
/ 15 апреля 2020

Я играю с TensorFlow, и обработка бэкэнда должна происходить в Java. Я взял одну из моделей из https://developers.google.com/machine-learning/crash-course и сохранил ее с помощью tf.saved_model.save (my_model, "house_price_median_income") (используя контейнер docker). Я скопировал модель и загрузил ее в Java (используя материал 2.0, созданный из исходного кода, потому что я на windows). Я могу загрузить модель и запустить ее:

   try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
    try (Session session = model.session()) {
        Session.Runner runner = session.runner();
        float[][] in = new float[][]{ {2.1518f} } ;

        Tensor<?> jack = Tensor.create(in);
        runner.feed("serving_default_layer1_input", jack);

        float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);

        for (int i = 0; i < probabilities.length; ++i) {
            System.out.println(String.format("-- Input #%d", i));
            for (int j = 0; j < probabilities[i].length; ++j) {
              System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
            }
          }
    }
 }

Выше указано жестко для ввода и вывода, но я хочу иметь возможность прочитать модель и предоставить некоторую информацию, чтобы конечный пользователь мог выбрать вход и вывод, et c.

Я могу получить входы и выходы с помощью команды python: save_model_cli show --dir ./house_price_median_income --all

Что я хочу сделать, это получить входы и выходы через Java, поэтому моему коду не нужно выполнять скрипт python, чтобы получить их. Я могу получить операции через:

 Graph graph = model.graph();
    Iterator<Operation> itr = graph.operations();
    while (itr.hasNext()) {
        GraphOperation e = (GraphOperation)itr.next();
        System.out.println(e);

И это выводит и входы, и выходы как «операции», НО, как мне узнать, что это вход и \ или выход? Инструмент python использует SignatureDef, но, похоже, его вообще нет в TensorFlow 2.0 java. Я упускаю что-то очевидное или просто отсутствует в библиотеке TensforFlow 2.0 Java?

1 Ответ

0 голосов
/ 15 апреля 2020

Что вам нужно сделать, это прочитать метаданные SavedModelBundle как MetaGraphDef, оттуда вы можете получить входные и выходные имена из SignatureDef, как в Python.

В TF Java 1. * (т. Е. Клиент, который вы используете в своем примере), определения прототипа не доступны из артефакта tensorflow, вам нужно добавить зависимость к org.tensorflow:proto как ну и десериализовать результат SavedModelBundle.metaGraphDef() в MetaGraphDef proto.

В TF Java 2. * (новый клиент фактически доступен только как снимки с здесь ), протосы присутствуют сразу, поэтому вы можете просто вызвать эту строку, чтобы получить право SignatureDef:

savedModel.metaGraphDef().signatureDefMap.getValue("serving_default")
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...