Tensorflow - Вопрос о прохождении операций из обученного RNN и генерации текста - PullRequest
0 голосов
/ 02 января 2019

Я написал RNN, который просматривает абзацы на уровне символов и хотел бы сохранить его для последующего использования.Некоторый код выглядит следующим образом:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = multi_cell.zero_state(batch_size, dtype=tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(multi_cell, rnn_inputs, initial_state=init_state)

with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [state_size, num_classes])
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))

rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])

logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits, name="predictions")

total_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, 
        labels=y_reshaped
    )
)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

И затем я использую tf.train.Saver() и saver.save(sess, "path/to/save") для сохранения моей модели.

Затем я пытаюсь загрузить свою модель в другом скрипте и сгенерировать текстиспользуя приведенный ниже код:

tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("path/to/save/save_file.meta")

with tf.Session() as sess:
    imported_meta.restore(sess, tf.train.latest_checkpoint("path/to/save"))
    x = sess.graph.get_tensor_by_name("input_placeholder:0")
    batch_size_tensor = sess.graph.get_tensor_by_name("batch_size:0")
    predictions = sess.graph.get_tensor_by_name("predictions:0")

    state = None
    current_char = vocab_to_index[start_char]

    for i in range(num_chars):
        if state is not None:
            feed_dict={batch_size_tensor: batch_size, x: [[current_char]], init_state: state}
        else:
            feed_dict={batch_size_tensor: batch_size, x: [[current_char]]}

        rnn_outputs, state = sess.run(
            [predictions, final_state], 
            feed_dict
        )

В основном здесь я хочу ввести символ, затем сгенерировать символ на основе предыдущего и снова.После начального символа final_state из dynamic_rnn должен быть sess.run() и вводиться в следующий процесс генерации как init_state.Однако я не смог найти способ сохранить init_state и final_state, определенные в обучающем коде, для загрузки в тестовый код, для этих операций нет аргумента "name", подобного tf.nn.softmax.

Мне нужен код вроде final_state = sess.graph.get_operation_by_name('final_state'), чтобы я мог sess.run(final_state) и передать его как init_state.

Я пытался использовать tf.add_to_collection("some_name", final_state) в обучающем коде и tf.get_collection("some_name"), но ошибка говорит о том, что коллекция "some_name" не может быть найдена в тестовом графе.

Кто-нибудь, кто написал модель генерации текста, сталкивался с этой проблемой на этапе генерации?Или как люди генерируют текст / сохраняют и загружают вывод из dynamic_rnn?

Заранее большое спасибо!

...