Загрузка операций из сохраненного классификатора TensorFlow RandomForest - PullRequest
0 голосов
/ 11 мая 2018

Я обучил классификатор случайных лесов TF, подобный следующему коду:

X = tf.placeholder(tf.float32, shape=[None, num_features])
Y = tf.placeholder(tf.int32, shape=[None])

hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                  num_features=num_features,
                                  num_trees=num_trees).fill()

forest_graph = tensor_forest.RandomForestGraphs(hparams)
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y,tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init_vars = tf.group(tf.global_variables_initializer(),
resources.initialize_resources(resources.shared_resources()))



with tf.Session() as sess:
    sess.run(init_vars)
    saver = tf.train.Saver()

    for i in range(1, 100):
        for batch_x, batch_y in render_batch(batch_size):
        _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
        if acc >= 0.87:
            print("Stopping and saving")
            save_path = saver.save(sess, model_path)
            print("Model saved in file: %s" % save_path)    
            break 

Теперь я хочу перезагрузить мою модель и использовать ее для прогнозирования невидимых данных, таких как:

with graph.as_default():
session_conf = tf.ConfigProto()
sess = tf.Session(config = session_conf)
with sess.as_default():
    saver = tf.train.import_meta_graph("{}.meta".format(model_path))
    saver.restore(sess,checkpoint_file)
    accuracy_op = graph.get_operation_by_name("accuracy_op").outputs[0]
    print(sess.run(accuracy_op, feed_dict={X: x_test, Y: y_test}))

Однако я получаю следующее сообщение об ошибке:

KeyError: "The name 'accuracy_op' refers to an Operation not in the graph."

У меня вопрос: как я могу сохранить свою модель, чтобы при ее перезагрузке я мог импортировать те операции, которые определены выше, и использовать их с невидимыми данными?

Спасибо!

1 Ответ

0 голосов
/ 11 мая 2018

Поскольку вы используете get_operation_by_name, вы должны были назвать оп accuracy_op. Вы можете сделать это, используя tf.identity:

 accuracy_op = tf.identity(tf.reduce_mean(tf.cast(correct_prediction, tf.float32)), 'accuracy_op')

Я вижу, что вы используете тензоры X и Y без загрузки из нового графика. Назовите тензоры в исходном коде, а затем перезагрузите, используя get_tensor_by_name()

...