Разбить тензорный поток. While_l oop условно - PullRequest
1 голос
/ 03 августа 2020

Я пытаюсь изменить код, генерирующий образец GPT-2 из вилки nshepperd (https://github.com/nshepperd/gpt-2).

В частности, код ниже, который является частью файла sample.py:

with tf.name_scope('sample_sequence'):
    # Don't feed the last context token -- leave that to the loop below
    # TODO: Would be slightly faster if we called step on the entire context,
    # rather than leaving the last token transformer calculation to the while loop.
    context_output = step(hparams, context[:, :-1])

    def body(past, prev, output):
        next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
        logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
        if penalize > 0.0:
            logits = penalize_used(logits, output, penalize=penalize)
        if top_p > 0.0:
            logits = top_p_logits(logits, p=top_p, epsilon=epsilon)
        else:
            logits = top_k_logits(logits, k=top_k, epsilon=epsilon)
        samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
        return [
            tf.concat([past, next_outputs['presents']], axis=-2),
            tf.squeeze(samples, axis=[1]),
            tf.concat([output, samples], axis=1),
        ]

    def cond(*args):
        return True

    _, _, tokens = tf.while_loop(
        cond=cond, body=body,
        maximum_iterations=length,
        loop_vars=[
            context_output['presents'],
            context[:, -1],
            context,
        ],
        shape_invariants=[
            tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
            tf.TensorShape([batch_size]),
            tf.TensorShape([batch_size, None]),
        ],
        back_prop=False,
    )

    return tokens

По сути, я пытаюсь остановить его, как только он сгенерирует токен с заданным значением c, например! EndText !. Однако, поскольку я новичок в тензорном потоке, я очень не уверен, как это сделать, тем более что официальная документация по этому поводу немного скудна. Если я правильно понимаю, мне нужно изменить функцию cond (которую я понимаю на l oop по всем выходам функции body), чтобы она ломалась, если en c .decode (output) == "! EndText! " однако я более или менее не знаю, с чего начать.

1 Ответ

0 голосов
/ 04 августа 2020

Я нашел ответ (благодаря значительной помощи от кого-то, кто более знаком с TF, чем я).

Единственная необходимая модификация - удалить «return True» из cond и вместо этого вставить следующий return оператор

return tf.math.logical_and(tf.not_equal(output[0][-1], tf.cast(X, tf.int32)), tf.not_equal(output[0][-2], tf.cast(Y, tf.int32)))

где X и Y представляют (в данном случае) два токена, которые должны сигнализировать об остановке, если вы хотите меньше или больше токенов (т.е. если ваши токены, которые сигнализируют об остановке менее / более различимы), просто добавьте / удалите термины в приведенном выше операторе logic_and.

...