TensorFlow не спасет рекуррентную нейронную сеть - PullRequest
0 голосов
/ 11 марта 2019

Я пишу рекуррентную нейронную сеть в Tensor Flow для Python. Вот код для обучения:

saver = tf.train.Saver()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    mse_list = []
    init.run()
    for iteration in range(n_iterations):
        print(iteration)
        dataOrig = allStocksDict[list(allStocksDict.keys())[iteration]]
        X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
        print(X_batch, y_batch)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
        print(iteration, "\tMSE", mse)
        mse_list.append(mse)
        print(datetime.datetime.now())
    print(mse_list)
    mse_sqrt = []
    for mse in mse_list:
        mse_sqrt.append(math.sqrt(mse))
    plt.plot(mse_sqrt)
    plt.ylabel("sqrt of mean squared error")
    plt.show()
    name = "iter1_rnn_monthly"
    saver.save(sess, os.getcwd() + "//RNNConfigs//" + name + "./" + name + ".cpkt")

Когда я запускаю это, сеть вычисляет градиенты и оптимизирует, выполняя весь код. Однако, как только он достигает строки для сохранения файла контрольной точки, программа ничего не делает и не останавливается - она ​​просто остается статичной, ничего не делая, вместо сохранения сети.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...