Я пишу рекуррентную нейронную сеть в Tensor Flow для Python. Вот код для обучения:
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
mse_list = []
init.run()
for iteration in range(n_iterations):
print(iteration)
dataOrig = allStocksDict[list(allStocksDict.keys())[iteration]]
X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
print(X_batch, y_batch)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
print(iteration, "\tMSE", mse)
mse_list.append(mse)
print(datetime.datetime.now())
print(mse_list)
mse_sqrt = []
for mse in mse_list:
mse_sqrt.append(math.sqrt(mse))
plt.plot(mse_sqrt)
plt.ylabel("sqrt of mean squared error")
plt.show()
name = "iter1_rnn_monthly"
saver.save(sess, os.getcwd() + "//RNNConfigs//" + name + "./" + name + ".cpkt")
Когда я запускаю это, сеть вычисляет градиенты и оптимизирует, выполняя весь код. Однако, как только он достигает строки для сохранения файла контрольной точки, программа ничего не делает и не останавливается - она просто остается статичной, ничего не делая, вместо сохранения сети.