Tensorflow: «недопустимая контрольная точка» - PullRequest
0 голосов
/ 10 марта 2019

Я пытаюсь восстановить текущую нейронную сеть из файла .cpkt. Мой код для восстановления сети:

graph = tf.Graph()
with graph.as_default():
    X = tf.placeholder(tf.float32, [1, n_steps, n_inputs])
    cell = tf.contrib.rnn.OutputProjectionWrapper(
        tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu),
        output_size=n_outputs
    )
    outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
    saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
    name = "rnnMonthly2"
    saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")
    X_batch = priceArrayToRNNFormat(getPriceArray(symbol="IBM")[-30:0])
    y_val = sess.run(feed_dict={X: X_batch})
    print(y_val)

Для справки, текстовый файл контрольных точек говорит, что путь к файлам контрольных точек следующий:

model_checkpoint_path: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
all_model_checkpoint_paths: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"

По этой причине я бы подумал, что с учетом пути к файлу, который я указал в saver.restore, модель должна быть восстановлена ​​должным образом. Однако, когда я запускаю код, я получаю следующее сообщение:

Traceback (most recent call last):
  File "/home/john/Python/StockProject/monthlyRnn1.py", line 151, in <module>
    saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt.index")
  File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
    + compat.as_text(save_path))
ValueError: The passed save_path is not a valid checkpoint: /home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt.index

В чем причина этой ошибки и что я могу сделать, чтобы исправить ее? Для справки, это код, который я использовал для обучения и сохранения сети:

saver = tf.train.Saver()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    mse_list = []
    init.run()
    for iteration in range(n_iterations):
        dataOrig = allStocksDict[list(allStocksDict.keys())[iteration]]
        X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
        print(X_batch, y_batch)
        print(X_batch, y_batch)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
        print(iteration, "\tMSE", mse)
        mse_list.append(mse)
    print(mse_list)
    name = "rnnMonthly2"
    saver.save(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")
...