Я создал свою собственную модель BertModel, которая выполняет классификацию последовательностей BertQuestionAnswering и Bert. Могу ли я использовать save_pretrained и from_pretrained для этого? - PullRequest
0 голосов
/ 11 апреля 2020

Я сделал многозадачную модель Берта на основе оригинальных моделей Берта. Если я тренируюсь в этой сети, могу ли я сделать: model = BertForMT.from_pretrained(checkpoint) и model.save_pretrained(output_dir) без проблем? Я планирую обучить эту модель одновременно ответам на вопросы, анализу настроений и MNLI. Я соединил примеры сценариев run_squad.py и run_glue.py для своего эксперимента. Поскольку их сценарии выполняют сохранение, а затем загрузку для оценки, я хотел убедиться, что все в порядке, если я слепо сохраню модели, используя функции from_pretrained и save_pretrained.

class BertForMT(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMT, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sst_classifier = nn.Linear(config.hidden_size, 2)
        self.mnli_classifier = nn.Linear(config.hidden_size, 3)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        labels=None,
        task='qa'
    ):

        if task == 'qa':
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            sequence_output = outputs[0]

            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

            outputs = (start_logits, end_logits,) + outputs[2:]
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions.clamp_(0, ignored_index)
                end_positions.clamp_(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                total_loss = (start_loss + end_loss) / 2
                outputs = (total_loss,) + outputs

            return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)

        elif task == 'sst-2' or task == 'mnli':

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]

            pooled_output = self.dropout(pooled_output)

            if task == 'sst-2':
                logits = self.sst_classifier(pooled_output)
            elif task == 'mnli':
                logits = self.mnli_classifier(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

            if labels is not None:
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)
...