Как я могу загрузить частично подготовленную модель Pytorch? - PullRequest
2 голосов
/ 14 апреля 2020

Я пытаюсь запустить модель pytorch в задаче классификации предложений. Поскольку я работаю с медицинскими записями, я использую ClinicalBert (https://github.com/kexinhuang12345/clinicalBERT) и хотел бы использовать его предварительно обученные веса. К сожалению, модель ClinicalBert только классифицирует текст в 1 двоичную метку, тогда как у меня 281 двоичная метка. Поэтому я пытаюсь реализовать этот код https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb, где конечный классификатор после Bert имеет длину 281.

Как я могу загрузить предварительно обученные веса Берта из модели ClinicalBert без загрузки весов классификации?

Наивно пытаясь загрузить веса из предварительно обученных весов ClinicalBert, я получаю следующую ошибку:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

В настоящее время я пытался заменить функцию from_pretrained из пакета pytorch_pretrained_bert и вывести веса и смещения классификатора следующим образом:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

И я получаю следующее сообщение об ошибке: INFO - моделирование_диагностики - Веса BertForMultiLabelSequenceClassification, не инициализированной из предварительно обученной модели: ['classifier.weight', 'classifier.bias']

В конце я хотел бы загрузить вложения bert из предварительно обученных весов ClinicalBert и иметь верхний классификатор веса инициализируются случайным образом.

1 Ответ

1 голос
/ 14 апреля 2020

Снятие ключей в состоянии dict перед загрузкой - хорошее начало. Предполагая, что вы используете nn.Module.load_state_dict для загрузки предварительно обученных весов, вам также потребуется установить аргумент strict=False, чтобы избежать ошибок из-за непредвиденных или отсутствующих ключей. Это будет игнорировать записи в state_dict, которых нет в модели (неожиданные ключи), и, что более важно для вас, оставит отсутствующие записи с их инициализацией по умолчанию (отсутствующие ключи). В целях безопасности вы можете проверить возвращаемое значение метода, чтобы убедиться, что рассматриваемые веса являются частью отсутствующих ключей, и что нет никаких неожиданных ключей.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...