Я написал последовательную модель (RNN) в Python, используя Tensorflow, которую я обучил большим объемам данных, которые мы используем в рамках более крупного проекта для анализа поведения взаимодействия с клиентами в кампании.
Характер проекта таков, что новые данные становятся доступными каждую неделю, и мы хотим иметь возможность включать их в существующую модель.
С этой целью, как часть начальной фазы обучения, я сохраняюмодель выглядит следующим образом:
# Initialize the variables (i.e. assign their
default value)
init = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Start training
with tf.Session() as sess:
# Run the initializer
sess.run(init)
# Train using batches
# ...
# ...
print("Optimization Finished!")
# Calculate and display accuracy on test set
# ...
# Save the variables to disk.
save_name = '{0}/{1}_seqmdl_session'.format(outputdir, datasetname)
save_path = saver.save(sess, save_name, write_meta_graph=True)
print("Model saved in path: '{0}'".format(save_path))
hypersave_name = '{0}/{1}_seqmdl_hyperparams.dat'.format(outputdir, datasetname)
dumpHyperParams(hypersave_name)
print("Hyper parameters saved in: '{0}'".format(hypersave_name))
Это выполняется в отдельном скрипте Python.
Затем на более позднем этапе я использую сценарий Jupyter Notebook для восстановления модели с целью либо 1) включить новые данные, либо 2) запустить прогноз:
model_path = '{0}/{1}_seqmdl_session'.format(outputdir, datasetname)
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
saver = tf.train.import_meta_graph(model_path + '.meta')
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
# Restore latest checkpoint
saver.restore(sess, model_path)
Токажется, работает, как я получаю:
INFO: tenorflow: восстановление параметров из данных / Live / Week2 / quickmails_2018-03-01-2018_04-09_quickmail_deliveries_merged_encrypted_ot_seqmdl_session
ОднакоКогда я пытаюсь запустить прогноз, скажем, я получаю следующую ошибку:
InvalidArgumentError: Количество пакетов 'then' должно соответствовать размеру 'cond', но при этом видно: 121680 против243360 [[Узел: rnn / cond_1 / cond / Select_1 = Выбрать [T = DT_FLOAT, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0"] (rnn / cond_1 / cond / Select/ Switch_1, rnn / cond_1 / cond / Select_1 / Switch, rnn / cond_1 / cond / Select_1 / Switch_1)]]
И я понятия не имею, как это исправить.Кто-нибудь может пролить свет на это, пожалуйста?