Ошибка при попытке повторного использования весов для RNN - PullRequest
0 голосов
/ 15 мая 2018

Я пытаюсь повторно использовать двунаправленные веса LSTM для двух очень похожих вычислений, но я получаю сообщение об ошибке и понятия не имею, что я делаю неправильно.У меня есть класс для базового модуля:

class BasicAttn(object):    
    def __init__(self, keep_prob, value_vec_size):    
        self.rnn_cell_fw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
        self.rnn_cell_fw = DropoutWrapper(self.rnn_cell_fw, input_keep_prob=self.keep_prob)
        self.rnn_cell_bw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
        self.rnn_cell_bw = DropoutWrapper(self.rnn_cell_bw, input_keep_prob=self.keep_prob)

    def build_graph(self, values, values_mask, keys):
        blended_reps = compute_blended_reps()
        with tf.variable_scope('BasicAttn_BRNN', reuse=True):
        (fw_out, bw_out), _ = 
        tf.nn.bidirectional_dynamic_rnn(self.rnn_cell_fw, self.rnn_cell_bw, blended_reps, dtype=tf.float32, scope='BasicAttn_BRNN')                                                      

Затем модуль вызывается при построении графика

    attn_layer_start = BasicAttn(...)
    blended_reps_start = attn_layer_start.build_graph(...)
    attn_layer_end = BasicAttn(...)
    blended_reps_end = attn_layer_end.build_graph(...)

Но я получаю сообщение об ошибке, в котором говорится, что TensorFlow не может повторно использоватьRNNs?

ValueError: Variable QAModel/BasicAttn_BRNN/BasicAttn_BRNN/fw/lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope

кода много, поэтому я обрезал части, которые мне показались ненужными.

1 Ответ

0 голосов
/ 15 мая 2018

reuse=True означает, что переменные были созданы ранее с reuse=False, поэтому каждый tf.get_variable (в вашем случае, абстрагированный от интерфейса LSTM) ожидает, что переменная уже существует.

Чтобы иметь режим, в котором переменные создаются, если они еще не существуют, и используются повторно в других случаях, необходимо установить reuse=tf.AUTO_REUSE (как следует из сообщения об ошибке).

Поэтому замените все вхождения reuse=True на reuse=tf.AUTO_REUSE

Вот документация: https://www.tensorflow.org/api_docs/python/tf/variable_scope

...