Как получить информацию о выравнивании или внимании для переводов, выполненных моделью хаба горелки? - PullRequest
1 голос
/ 22 марта 2020

Концентратор резака предоставляет модели с предварительной подготовкой, такие как: https://pytorch.org/hub/pytorch_fairseq_translation/

Эти модели можно использовать в python или интерактивно с CLI. С CLI возможно получить выравнивания, с флагом --print-alignment. Следующий код работает в терминале, после установки fairseq (и pytorch)

curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \ 
    --print-alignment

В python можно указать ключевое слово args verbose и print_alignment:

import torch

en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')

fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)

Однако это будет выводить выравнивание только в виде сообщения регистрации. А для fairseq 0.9 он, похоже, не работает и выдает сообщение об ошибке ( проблема ).

Есть ли способ получить информацию о выравнивании (или, возможно, даже матрицу полного внимания) из python код?

1 Ответ

1 голос
/ 22 марта 2020

Я просмотрел кодовую базу fairseq и нашел хакерский способ вывода информации о выравнивании. Поскольку это требует редактирования самого исходного кода fairseq, я не думаю, что это приемлемое решение. Но, возможно, это кому-то поможет (мне все еще очень интересен ответ о том, как это сделать правильно).

Отредактируйте функцию sample () и переписайте инструкцию return. Вот целая функция (чтобы помочь вам найти ее лучше в коде), но следует изменить только последнюю строку:

def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
    if isinstance(sentences, str):
        return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
    tokenized_sentences = [self.encode(sentence) for sentence in sentences]
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
    return list(zip([self.decode(hypos[0]['tokens']) for hypos in batched_hypos], [hypos[0]['alignment'] for hypos in batched_hypos]))
...