Я пытаюсь изменить код, генерирующий образец GPT-2 из вилки nshepperd (https://github.com/nshepperd/gpt-2).
В частности, код ниже, который является частью файла sample.py:
with tf.name_scope('sample_sequence'):
# Don't feed the last context token -- leave that to the loop below
# TODO: Would be slightly faster if we called step on the entire context,
# rather than leaving the last token transformer calculation to the while loop.
context_output = step(hparams, context[:, :-1])
def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
if penalize > 0.0:
logits = penalize_used(logits, output, penalize=penalize)
if top_p > 0.0:
logits = top_p_logits(logits, p=top_p, epsilon=epsilon)
else:
logits = top_k_logits(logits, k=top_k, epsilon=epsilon)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]),
tf.concat([output, samples], axis=1),
]
def cond(*args):
return True
_, _, tokens = tf.while_loop(
cond=cond, body=body,
maximum_iterations=length,
loop_vars=[
context_output['presents'],
context[:, -1],
context,
],
shape_invariants=[
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]),
tf.TensorShape([batch_size, None]),
],
back_prop=False,
)
return tokens
По сути, я пытаюсь остановить его, как только он сгенерирует токен с заданным значением c, например! EndText !. Однако, поскольку я новичок в тензорном потоке, я очень не уверен, как это сделать, тем более что официальная документация по этому поводу немного скудна. Если я правильно понимаю, мне нужно изменить функцию cond (которую я понимаю на l oop по всем выходам функции body), чтобы она ломалась, если en c .decode (output) == "! EndText! " однако я более или менее не знаю, с чего начать.