Я написал RNN, который просматривает абзацы на уровне символов и хотел бы сохранить его для последующего использования.Некоторый код выглядит следующим образом:
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = multi_cell.zero_state(batch_size, dtype=tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(multi_cell, rnn_inputs, initial_state=init_state)
with tf.variable_scope('softmax'):
W = tf.get_variable('W', [state_size, num_classes])
b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits, name="predictions")
total_loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits,
labels=y_reshaped
)
)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
И затем я использую tf.train.Saver()
и saver.save(sess, "path/to/save")
для сохранения моей модели.
Затем я пытаюсь загрузить свою модель в другом скрипте и сгенерировать текстиспользуя приведенный ниже код:
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("path/to/save/save_file.meta")
with tf.Session() as sess:
imported_meta.restore(sess, tf.train.latest_checkpoint("path/to/save"))
x = sess.graph.get_tensor_by_name("input_placeholder:0")
batch_size_tensor = sess.graph.get_tensor_by_name("batch_size:0")
predictions = sess.graph.get_tensor_by_name("predictions:0")
state = None
current_char = vocab_to_index[start_char]
for i in range(num_chars):
if state is not None:
feed_dict={batch_size_tensor: batch_size, x: [[current_char]], init_state: state}
else:
feed_dict={batch_size_tensor: batch_size, x: [[current_char]]}
rnn_outputs, state = sess.run(
[predictions, final_state],
feed_dict
)
В основном здесь я хочу ввести символ, затем сгенерировать символ на основе предыдущего и снова.После начального символа final_state
из dynamic_rnn должен быть sess.run()
и вводиться в следующий процесс генерации как init_state
.Однако я не смог найти способ сохранить init_state
и final_state
, определенные в обучающем коде, для загрузки в тестовый код, для этих операций нет аргумента "name", подобного tf.nn.softmax
.
Мне нужен код вроде final_state = sess.graph.get_operation_by_name('final_state')
, чтобы я мог sess.run(final_state)
и передать его как init_state
.
Я пытался использовать tf.add_to_collection("some_name", final_state)
в обучающем коде и tf.get_collection("some_name")
, но ошибка говорит о том, что коллекция "some_name" не может быть найдена в тестовом графе.
Кто-нибудь, кто написал модель генерации текста, сталкивался с этой проблемой на этапе генерации?Или как люди генерируют текст / сохраняют и загружают вывод из dynamic_rnn?
Заранее большое спасибо!