Я пытаюсь восстановить текущую нейронную сеть из файла .cpkt. Мой код для восстановления сети:
graph = tf.Graph()
with graph.as_default():
X = tf.placeholder(tf.float32, [1, n_steps, n_inputs])
cell = tf.contrib.rnn.OutputProjectionWrapper(
tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu),
output_size=n_outputs
)
outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
name = "rnnMonthly2"
saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")
X_batch = priceArrayToRNNFormat(getPriceArray(symbol="IBM")[-30:0])
y_val = sess.run(feed_dict={X: X_batch})
print(y_val)
Для справки, текстовый файл контрольных точек говорит, что путь к файлам контрольных точек следующий:
model_checkpoint_path: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
all_model_checkpoint_paths: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
По этой причине я бы подумал, что с учетом пути к файлу, который я указал в saver.restore, модель должна быть восстановлена должным образом. Однако, когда я запускаю код, я получаю следующее сообщение:
Traceback (most recent call last):
File "/home/john/Python/StockProject/monthlyRnn1.py", line 151, in <module>
saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt.index")
File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
+ compat.as_text(save_path))
ValueError: The passed save_path is not a valid checkpoint: /home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt.index
В чем причина этой ошибки и что я могу сделать, чтобы исправить ее? Для справки, это код, который я использовал для обучения и сохранения сети:
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
mse_list = []
init.run()
for iteration in range(n_iterations):
dataOrig = allStocksDict[list(allStocksDict.keys())[iteration]]
X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
print(X_batch, y_batch)
print(X_batch, y_batch)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
print(iteration, "\tMSE", mse)
mse_list.append(mse)
print(mse_list)
name = "rnnMonthly2"
saver.save(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")