Я строю модель Seq2Seq генератора текстовых сводок, которая использует сеть указателей для выборки из входных данных с вероятностью p, в противном случае она декодирует следующее целевое слово, используя фиксированный словарь.Как я могу написать код для декодирования пакета примеров во время вывода без цикла для каждого примера?
def predict_batch(self, X):
assert self.embeddings, "Call self.set_embeddings_layer first"
X = self.embeddings(X)
enc_states, h1, h2 = self.encoder(X)
input_tokens = tf.convert_to_tensor([self.start_token] * X.shape[0])
# put last encoder state as attention vec at start
c_vec = h1
outputs = []
for _ in range(self.max_len):
dec_input = self.embeddings(input_tokens)
decoded_state, h1, h2 = self.decoder(dec_input, c_vec, [h1, h2])
c_vec, _, pointer_prob = self.attention(enc_states,
decoded_state)
# Compute switch probability to decide if to extract the next
# word token with a pointer network or a fixed vocabulary
switch_probs = self.pointer_switch(h1, c_vec)
...
После того, как я вычислю вероятности переключения, мне нужно выполнить другой код, основанный на этих вероятностях.
Например, если switch_probs [0,2, 0,8, 0,5] и я сгенерировал несколько случайных чисел, таких как [0,4, 0,7, 0,6], мне нужно выполнить функцию A для примеров [0,2] и функцию BНапример, [1].
Есть ли способ избежать зацикливания для каждого примера и сделать это с помощью какого-нибудь эффективного API Tensorflow?
Вот пример того, что я пытаюсь выполнить: https://arxiv.org/pdf/1704.04368.pdf