Как использовать tf.contrib.seq2seq в нетерпеливом исполнении? - PullRequest
0 голосов
/ 06 ноября 2018

Я работаю над моделью seq2seq, в которой используется модуль tf.contrib.seq2seq, и я хотел бы сделать ее совместимой как с графиком, так и с нетерпением к выполнению. Насколько я знаю, tf.contrib.seq2seq, к сожалению, пока не работает в нетерпеливом исполнении (поправьте меня, если я ошибаюсь). Поэтому я хотел вставить его с tfe.defun(), но получил следующую ошибку:

RuntimeError: tf.device does not support functions when eager execution is enabled.

в следующей строке:

final_outputs, final_state, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True

Две части, где я использую tfe.defun():

def train_decoder_model(self, function_encoder_outputs, decoder_inputs, target_sequence_length):
    with tf.name_scope('embedding'):
        #  The embedding layer expects integer instead of one-hot encodings.
        decoder_inputs_ints = tf.argmax(decoder_inputs, axis=-1)
        #  Perform the embedding on the decoder input.
        decoder_embedding = tf.nn.embedding_lookup(self._emb_matrix, decoder_inputs_ints)
    with tf.name_scope('decoder'):
        target_sequence_length = tf.cast(target_sequence_length, tf.int32)
        def graph_decoder():
            helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, target_sequence_length)
            decoder = tf.contrib.seq2seq.BasicDecoder(self.decoder_cell, helper,
                                                      function_encoder_outputs, self.decoder_dense)
            final_outputs, final_state, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
            return final_outputs
        graph_func = tfe.defun(graph_decoder)
        final_outputs = graph_func()
    return final_outputs

def loss(self, labels, logits, doc_length):
    def graph_loss():
        masks = tf.sequence_mask(doc_length, tf.reduce_max(doc_length), dtype=tf.float32, name='masks')
        return tf.contrib.seq2seq.sequence_loss(logits, tf.argmax(labels, -1), masks)
    graph_func = tfe.defun(graph_loss)
    return graph_func()

Любая подсказка, как решить эту проблему?

...