Tensorflow, служащий в Го - PullRequest
0 голосов
/ 10 мая 2018

Я пытаюсь запустить модель keras в Go. Сначала я тренирую модель на питоне:

import keras as krs
from keras import backend as K
import tensorflow as tf

sess = tf.Session()
K.set_session(sess)
K._LEARNING_PHASE = tf.constant(0)
K.set_learning_phase(0)

m1 = krs.models.Sequential()
m1.Add(krs.layers.Dense(..., name="inputNode"))
...
m1.Add(krs.layers.Dense(..., activation="softmax", name="outputNode"))
m1.compile(...)
m1.fit(...)

Тогда я понимаю, что рекомендуется заморозить модель - преобразовать заполнитель в константы.

saver = tf.train.Saver()
tf.train.write_graph(sess.graph_def, '.', 'my_model.pbtxt')
saver.save(sess, save_path="my_model.ckpt")

from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

freeze_graph.freeze_graph(input_graph = 'my_model.pbtxt',  input_saver = "",
                 input_binary = False, input_checkpoint = "my_model.ckpt", output_node_names = "outputNode/Softmax",
                 restore_op_name = "save/restore_all", filename_tensor_name = "save/Const:0",
                 output_graph = "frozen_my_model.pb", clear_devices = True, initializer_nodes = "")

При попытке использовать замороженную модель на Голанге:

model, err := tf.LoadSavedModel("frozen_my_model.pb", []string{"serve"}, nil)

Возвращает ошибку, что тег не найден SavedModel load for tags { serve }; Status: fail.

Поэтому мои вопросы:

  1. Как заморозить модель в Python, а затем загрузить ее в Go
  2. Я делаю это, чтобы ускорить вывод в Go - правильно ли это замораживание модели улучшат скорость вывода?
  3. Я заметил, что существует другая функция optimize_for_inference, как бы это реализовано в вышеуказанных настройках?

1 Ответ

0 голосов
/ 10 мая 2018

Вы должны «пометить» обученную модель, используя

    # Create a builder to export the model
    builder = tf.saved_model.builder.SavedModelBuilder("export")
    # Tag the model in order to be capable of restoring it specifying the tag set
    builder.add_meta_graph_and_variables(sess, ["tag"])
    builder.save()

. После этого вы можете загрузить ее в Go.

Однако, более удобным решением является использование * 1006.* tfgo

Как вы можете видеть в README, есть код для обоих: train in python и inference in Go.Я сообщу вам здесь:

Python: обучаем LeNet на MNIST (пример)

import sys
import tensorflow as tf
from dytb.inputs.predefined.MNIST import MNIST
from dytb.models.predefined.LeNetDropout import LeNetDropout
from dytb.train import train

def main():
    """main executes the operations described in the module docstring"""
    lenet = LeNetDropout()
    mnist = MNIST()

    info = train(
        model=lenet,
        dataset=mnist,
        hyperparameters={"epochs": 2},)

    checkpoint_path = info["paths"]["best"]

    with tf.Session() as sess:
        # Define a new model, import the weights from best model trained
        # Change the input structure to use a placeholder
        images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="input_")
        # define in the default graph the model that uses placeholder as input
        _ = lenet.get(images, mnist.num_classes)

        # The best checkpoint path contains just one checkpoint, thus the last is the best
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

        # Create a builder to export the model
        builder = tf.saved_model.builder.SavedModelBuilder("export")
        # Tag the model in order to be capable of restoring it specifying the tag set
        builder.add_meta_graph_and_variables(sess, ["tag"])
        builder.save()

    return 0


if __name__ == '__main__':
    sys.exit(main())

Go: вывод

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
        model := tg.LoadModel("test_models/export", []string{"tag"}, nil)

        fakeInput, _ := tf.NewTensor([1][28][28][1]float32{})
        results := model.Exec([]tf.Output{
                model.Op("LeNetDropout/softmax_linear/Identity", 0),
        }, map[tf.Output]*tf.Tensor{
                model.Op("input_", 0): fakeInput,
        })

        predictions := results[0].Value().([][]float32)
        fmt.Println(predictions)
}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...