Seq2Seq: Как декодировать во время вывода, используя указатель сети и фиксированный словарь - PullRequest
0 голосов
/ 04 мая 2019

Я строю модель Seq2Seq генератора текстовых сводок, которая использует сеть указателей для выборки из входных данных с вероятностью p, в противном случае она декодирует следующее целевое слово, используя фиксированный словарь.Как я могу написать код для декодирования пакета примеров во время вывода без цикла для каждого примера?

    def predict_batch(self, X):
        assert self.embeddings, "Call self.set_embeddings_layer first"
        X = self.embeddings(X)
        enc_states, h1, h2 = self.encoder(X)
        input_tokens = tf.convert_to_tensor([self.start_token] * X.shape[0])

        # put last encoder state as attention vec at start
        c_vec = h1
        outputs = []

        for _ in range(self.max_len):
            dec_input = self.embeddings(input_tokens)
            decoded_state, h1, h2 = self.decoder(dec_input, c_vec, [h1, h2])
            c_vec, _, pointer_prob  = self.attention(enc_states, 
                                                     decoded_state)

            # Compute switch probability to decide if to extract the next
            # word token with a pointer network or a fixed vocabulary
            switch_probs = self.pointer_switch(h1, c_vec)

            ...

После того, как я вычислю вероятности переключения, мне нужно выполнить другой код, основанный на этих вероятностях.

Например, если switch_probs [0,2, 0,8, 0,5] и я сгенерировал несколько случайных чисел, таких как [0,4, 0,7, 0,6], мне нужно выполнить функцию A для примеров [0,2] и функцию BНапример, [1].

Есть ли способ избежать зацикливания для каждого примера и сделать это с помощью какого-нибудь эффективного API Tensorflow?

Вот пример того, что я пытаюсь выполнить: https://arxiv.org/pdf/1704.04368.pdf

...