Как применить обученную нейронную сеть? - PullRequest
1 голос
/ 30 мая 2019

[Python 3.7, Tensorflow] Я обучил нейронную сеть.Все работает хорошо, он учится, но как только он закончил обучение, он просто отключается, и прогресс теряется.Теперь я хочу ввести новые данные и вручную посмотреть, насколько хорошо работает сеть.

Я уже возился с

saver = tf.train.Saver()
saver.save(sess, 'model/model.ckpt')

, но это всегда приводит к миледлинный отчет об ошибке, заканчивающийся «Неизвестная ошибка: не удалось переименовать« model / model.ckpt »» и т. д.

Код в контексте выглядит следующим образом:

def train_neural_network(x):
    training_data = generate_training_data() # i cut getting training data since its a bit out of context here, but its basically like mnist data

    prediction = neural_network_model(x) # normal, 3-layer feed forward NN
    cost = tf.reduce_mean( tf.square(prediction - y) )
    optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)

    hm_epochs = 10
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(hm_epochs+1):
            epoch_loss = 0

            for i in range(10):
                epoch_x, epoch_y = training_data

                _, c = sess.run([optimizer, cost], feed_dict = {x: epoch_x, y: epoch_y})

        saver.save(sess, 'model/model.ckpt')

Я пытаюсь вызвать этоОбученная нейронная сеть в основном:

train_neural_network(x)

X, Y = generate_training_data()

prediction = neural_network_model(x)
saver = tf.train.Saver()

with tf.Session() as sess:

     saver.restore(sess, 'model/model.ckpt')
     result = sess.run(prediction, feed_dict={x: X})

print(Y, result)

Пока все в одном файле, но я также могу сделать с двумя отдельными файлами.

Это приводит к ошибке, которая говорит о обычной ошибке питона, состоящей из его пути и заканчивающейся
"... в _do_call повысить тип (e) (node_def, op, message)" перед a,я думаю, что возникает ошибка, связанная с Tensorflow: «Неизвестная ошибка: не удалось переименовать« model / model.ckpt »» и «Вызвано операцией« save_13 / SaveV2 », определенной в:», а затем длинный, длинный путь,около 87 строк, «Неизвестная ошибка» повторяется снова.

Я хочу получить распечатанную этикетку с прогнозируемым выходным сигналом нейронной сети.(строка печати в коде.)

К сожалению, я пока не нашел ничего, что бы работало в различных поисках в Интернете, но я чувствую, что не должно быть слишком сложно заставить это работать.Заранее спасибо.

1 Ответ

1 голос
/ 30 мая 2019

Если вы посмотрите в папку, где ваша модель выводит контрольные точки (/ model), вы должны увидеть 3 отдельных файла для каждого сохранения: model.ckpt-xxx.data, model.ckpt-xxx.index и model.ckpt-xxx.meta, где xxx - это идентификатор контрольной точки, добавленной Tensorflow.

Если вы хотите восстановить определенную контрольную точку, вам также необходимо добавить идентификатор, потому что обычно во время обучения создаются несколько контрольных точек одной и той же сети, чтобы мы могли при необходимости перенастроить сеть позже.

Так что я бы посмотрел в папке модели и дважды проверил имя файла, я думаю, saver.restore(sess, 'model/model.ckpt-0') сработает, если вы создадите только одну контрольную точку.

...