Как сохранить объект, представляющий нейронную сеть, построенную в Tensorflow - PullRequest
0 голосов
/ 06 июня 2019

Я новичок в Tensorflow и играю с кодом на github. Этот код создает класс для нейронной сети, который включает методы для построения сети, формулирования функции потерь, обучения сети, выполнения прогнозирования и т. Д.

Скелетный код будет выглядеть примерно так:

class NeuralNetwork:
    def __init__(...):

    def initializeNN():

    def trainNN():

    def predictNN():

и т.д.. Нейронная сеть построена с использованием Tensorflow, следовательно, определение класса и его методы используют синтаксис тензорного потока.

Теперь в основной части моего скрипта я создаю экземпляр этого класса через

model = NeuralNetwork(...)

и использовать методы model, такие как model.predict, для создания графиков.

Поскольку обучение нейронной сети занимает много времени, я хотел бы сохранить «модель» объекта для будущего использования и с возможностью вызова его методов. Я пробовал мариновать и укроп, но они оба не смогли. Для маринада я получил ошибку:

TypeError: невозможно выбрать объекты _thread.RLock

пока для укропа я получил:

TypeError: невозможно выбрать объекты SwigPyObject

Любые предложения, как я могу сохранить объект и при этом иметь возможность вызывать его методы? Это очень важно, поскольку в будущем мне может потребоваться выполнить прогнозирование для другого набора точек.

Спасибо!

1 Ответ

0 голосов
/ 06 июня 2019

Что вы должны сделать, это следующее:

# Build the graph
model = NeuralNetwork(...)
# Create a train saver/loader object
saver = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Train the model in the same way you are doing it currently
    model.train_model()
    # Once you are done training, just save the model definition and it's learned weights
    saver.save(sess, save_path)

И все готово.Затем, когда вы захотите снова использовать модель, вы можете:

# Build the graph
model = NeuralNetwork()
# Create a train saver/loader object
loader = tf.train.Saver()
# Create a session
with tf.Session() as sess:
    # Load the model variables
    loader.restore(sess, save_path)
    # Train the model again for example
    model.train_model()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...