Проблемы с сохранением и восстановлением тензор потока RNN - PullRequest
0 голосов
/ 20 ноября 2018

Я написал последовательную модель (RNN) в Python, используя Tensorflow, которую я обучил большим объемам данных, которые мы используем в рамках более крупного проекта для анализа поведения взаимодействия с клиентами в кампании.

Характер проекта таков, что новые данные становятся доступными каждую неделю, и мы хотим иметь возможность включать их в существующую модель.

С этой целью, как часть начальной фазы обучения, я сохраняюмодель выглядит следующим образом:

# Initialize the variables (i.e. assign their 
default value)
init = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Train using batches
    # ...
    # ...

    print("Optimization Finished!")

    # Calculate and display accuracy on test set
    # ...

    # Save the variables to disk.

    save_name = '{0}/{1}_seqmdl_session'.format(outputdir, datasetname)
    save_path = saver.save(sess, save_name, write_meta_graph=True)
    print("Model saved in path: '{0}'".format(save_path))

    hypersave_name = '{0}/{1}_seqmdl_hyperparams.dat'.format(outputdir, datasetname)
    dumpHyperParams(hypersave_name)
    print("Hyper parameters saved in: '{0}'".format(hypersave_name))

Это выполняется в отдельном скрипте Python.

Затем на более позднем этапе я использую сценарий Jupyter Notebook для восстановления модели с целью либо 1) включить новые данные, либо 2) запустить прогноз:

model_path = '{0}/{1}_seqmdl_session'.format(outputdir, datasetname)

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    saver = tf.train.import_meta_graph(model_path + '.meta')

with tf.Session(graph=graph) as sess:

    sess.run(tf.global_variables_initializer())

    # Restore latest checkpoint
    saver.restore(sess, model_path)

Токажется, работает, как я получаю:

INFO: tenorflow: восстановление параметров из данных / Live / Week2 / quickmails_2018-03-01-2018_04-09_quickmail_deliveries_merged_encrypted_ot_seqmdl_session

ОднакоКогда я пытаюсь запустить прогноз, скажем, я получаю следующую ошибку:

InvalidArgumentError: Количество пакетов 'then' должно соответствовать размеру 'cond', но при этом видно: 121680 против243360 [[Узел: rnn / cond_1 / cond / Select_1 = Выбрать [T = DT_FLOAT, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0"] (rnn / cond_1 / cond / Select/ Switch_1, rnn / cond_1 / cond / Select_1 / Switch, rnn / cond_1 / cond / Select_1 / Switch_1)]]

И я понятия не имею, как это исправить.Кто-нибудь может пролить свет на это, пожалуйста?

...