Обучающий классификатор текста пространственно-трансформеров - пример минимального обучения - PullRequest
0 голосов
/ 17 июня 2020

После ряда экспериментов по линиям, указанным в этих примерах:

https://colab.research.google.com/github/explosion/spacy-pytorch-transformers/blob/master/examples/Spacy_Transformers_Demo.ipynb

https://github.com/explosion/spacy-transformers/blob/master/examples/train_textcat.py

Я обнаружил, что не могу наблюдать эффект любого обучения при вызове nlp.update на моделях пространственного преобразователя. Я пробовал с en_trf_bertbaseuncased_lg, как показано ниже, и с моделью en_trf_distilbertbaseuncased_lg безуспешно. Однако я могу получить классификацию текста с помощью просторных примеров TextCategorizer и LSTM, которые работают.

Следовательно, я хотел бы спросить, что я мог бы сделать, чтобы изменить приведенный ниже код, чтобы получить результат менее 1.0 для «THE_POSITIVE_LABEL» при вызове do c .cats для этого тестового предложения. В настоящее время он работает без ошибок, но всегда возвращает 1.0 для оценки. Я попытался использовать этот пример после запуска правильного набора тренировок и наблюдения за идентичными P, R, F, которые оценивают значение потерь, которое просто прыгает вокруг каждой оценки. Исправленная версия может затем служить в качестве простой проверки работоспособности.

import spacy
from collections import Counter

nlp = spacy.load('en_trf_bertbaseuncased_lg')

textcat = nlp.create_pipe(
      "trf_textcat",
    config={
        "architecture": "softmax_class_vector", # have also tried "softmax_last_hidden" with "words_per_batch" like in one of the examples
        'token_vector_width': 768  # added as otherwise it complains about textcat config not having 'token_vector_width'
    }
)

textcat.add_label("THE_POSITIVE_LABEL")

nlp.add_pipe(textcat, last=True)

nlp.begin_training() # added as otherwise it says trf_textcat has no model when we call doc.cats

print(nlp("an example of a document that does not  match the label").cats)

#{'THE_POSITIVE_LABEL': 1.0} is printed

optimizer = nlp.resume_training()
optimizer.alpha = 0.001
optimizer.trf_weight_decay = 0.005
optimizer.L2 = 0.0
optimizer.trf_lr = 2e-5

losses = Counter()

texts = ['an example of a document that does not  match the label',]

annotations = [{'THE_POSITIVE_LABEL': 0.},]

nlp.update(texts, annotations, sgd=optimizer, drop=0.1, losses=losses)

print(nlp("an example of a document that does not  match the label").cats)

#{'THE_POSITIVE_LABEL': 1.0} is again printed
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...