Sequence labeling with BertForTokenClassification from Transformers
0 votes
May 7, 2020

I'm trying to use BertForTokenClassification from Transformers with a pretrained multilingual BERT model to perform a sequence labeling task. During training everything looks fine and the loss decreases each epoch, but when I put the model in evaluation mode and feed it inputs, the logits for every token come out identical.

My code looks like this:

import time

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from tqdm import tqdm
from transformers import (AdamW, BertForTokenClassification, BertTokenizer,
                          get_linear_schedule_with_warmup)

# get_analysis, binary_analysis and proportional_analysis used in evaluate()
# are project-specific helpers (not shown here).


class PretrainedBert(nn.Module):

    def __init__(self,
                 device,
                 pretrained_path='bert-base-multilingual-cased',
                 output_dim=5,
                 learning_rate=0.01,
                 eps=1e-8,
                 max_grad_norm=1.0,
                 model_name='bert_multilingual',
                 cased=True
                 ):
        super(PretrainedBert, self).__init__()

        self.device = device
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        self.max_grad_norm = max_grad_norm
        self.model_name = model_name
        self.cased = cased

        # Load pretrained bert
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_path,
            num_labels=output_dim,
            output_attentions=False,
            output_hidden_states=False)
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_path, do_lower_case=not cased)

        param_optimizer = list(self.bert.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=eps)


    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask, labels=labels)        
        return output


    def fit(self, train_loader, dev_loader, epochs=10):
        """
        :param test_loader: torch.utils.data.DataLoader object
        :param test_loader: torch.utils.data.DataLoader object
        :param epochs: int
        """
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(train_loader) * epochs)

        for epoch in range(epochs):
            # Iterate over training data
            epoch_time = time.time()
            self.train()
            epoch_loss = 0
            num_batches = 0

            for raw_text, x, y, idx in tqdm(train_loader):                
                self.zero_grad()

                # Get gold labels and extend them to cover the wordpieces created by bert
                y = self._extend_labels(raw_text, y)
                y, _ = pad_packed_sequence(y, batch_first=True)
                batches_len, seq_length = y.size()
                y = y.to(self.device)  # .to() is not in-place, so assign the result back

                # Run input through bert and get logits
                bert_tokenized = self.tokenizer.batch_encode_plus(
                    [' '.join(text) for text in raw_text],
                    max_length=seq_length,
                    pad_to_max_length=True)

                input_ids = torch.LongTensor(bert_tokenized['input_ids']).to(self.device)
                attention_masks = torch.LongTensor(bert_tokenized['attention_mask']).to(self.device)

                loss, logits = self.forward(input_ids, attention_masks, labels=y)

                loss.backward()

                # Clip gradients after backward() to help prevent the
                # "exploding gradients" problem.
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.bert.parameters(), max_norm=self.max_grad_norm)

                # Update parameters and learning rate
                self.optimizer.step()
                scheduler.step()

                epoch_loss += loss.item()
                num_batches += 1

            epoch_time = time.time() - epoch_time

            print()
            print("Epoch {0} loss: {1:.3f}".format(epoch + 1,
                                                   epoch_loss / num_batches))

            print("Dev")
            binary_f1, propor_f1 = self.evaluate(dev_loader)


    def predict(self, test_loader):
        """
        :param test_loader: torch.utils.data.DataLoader object with
                            batch_size=1
        """
        self.eval()
        predictions, golds, sents = [], [], []

        for raw_text, x, y, idx in tqdm(test_loader):     
            # Run input through bert and get logits
            bert_tokenized = self.tokenizer.encode_plus(' '.join(raw_text[0]))

            input_ids, token_type_ids, attention_masks = bert_tokenized[
                'input_ids'], bert_tokenized['token_type_ids'], bert_tokenized['attention_mask']

            with torch.no_grad():
                logits = self.forward(
                    torch.LongTensor(input_ids).unsqueeze(0).to(self.device),
                    torch.LongTensor(attention_masks).unsqueeze(0).to(self.device))

            # remove batch dim and [CLS]+[SEP] tokens from logits
            logits = logits[0].squeeze(0)[1:-1]

            # mean pool wordpiece rows of logits and argmax to get predictions 
            preds = self._logits_to_preds(logits, y[0], raw_text[0])

            predictions.append(preds)
            golds.append(y[0])
            sents.append(raw_text[0])

        return predictions, golds, sents


    def evaluate(self, test_loader):
        """
        Returns the binary and proportional F1 scores of the model on the examples passed via test_loader.
        :param test_loader: torch.utils.data.DataLoader object with
                            batch_size=1
        """
        preds, golds, sents = self.predict(test_loader)


        flat_preds = [int(i) for l in preds for i in l]
        flat_golds = [int(i) for l in golds for i in l]

        analysis = get_analysis(sents, preds, golds)
        binary_f1 = binary_analysis(analysis)
        propor_f1 = proportional_analysis(flat_golds, flat_preds)

        return binary_f1, propor_f1


    def _extend_labels(self, text, labels):
        extended_labels = []

        for text_seq, label_seq in zip(text, labels):
            extended_seq_labels = []

            for word, label in zip(text_seq, label_seq):
                n_subwords = len(self.tokenizer.tokenize(word))
                extended_seq_labels.extend([label.item()] * n_subwords)

            extended_labels.append(torch.LongTensor(extended_seq_labels))

        extended_labels = pack_sequence(extended_labels, enforce_sorted=False)

        return extended_labels


    def _logits_to_preds(self, logits, y, raw_text):
        preds = torch.zeros(len(y))
        head = 0
        for idx, word in enumerate(raw_text):
            n_subwords = len(self.tokenizer.tokenize(word))

            if n_subwords > 1:
                preds[idx] = torch.argmax(torch.mean(
                    logits[head:head+n_subwords, :], dim=0))
            else:
                preds[idx] = torch.argmax(logits[head, :])

            head = head + n_subwords

        return preds
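
For context, this is roughly how I build the loaders and drive the class (a simplified, self-contained sketch of the data shape; the real dataset and the get_analysis/binary_analysis/proportional_analysis helpers used by evaluate() are project-specific, so the sketch only calls predict()):

import torch
from torch.utils.data import DataLoader

# Each item mirrors what my real loaders yield: (raw_text, x, y, idx),
# where raw_text is the list of tokens and y the per-token label tensor.
toy_data = [
    (['Donkey', 'Kong', 'Country', ':'], None, torch.tensor([0, 0, 0, 0]), 0),
    (['Hello', 'world', '!'], None, torch.tensor([0, 1, 0]), 1),
]

def collate(batch):
    raw_text, x, y, idx = zip(*batch)
    return list(raw_text), list(x), list(y), list(idx)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader = DataLoader(toy_data, batch_size=2, collate_fn=collate)
test_loader = DataLoader(toy_data, batch_size=1, collate_fn=collate)  # predict() expects batch_size=1

model = PretrainedBert(device=device, output_dim=5).to(device)

# In the real run I call model.fit(train_loader, dev_loader, epochs=10) first;
# predict() alone is enough to reproduce the identical-logits behaviour.
predictions, golds, sents = model.predict(test_loader)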

Example input:

raw_text = [['Donkey', 'Kong', 'Country', ':']]
y =  [tensor([0, 0, 0, 0])]
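
For reference, the wordpiece splitting that _extend_labels and _logits_to_preds rely on for this example can be inspected like this (the exact subword splits depend on the bert-base-multilingual-cased vocabulary, so I'm not hard-coding them here):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased',
                                          do_lower_case=False)

words = ['Donkey', 'Kong', 'Country', ':']
for word in words:
    # the number of wordpieces per word is how many times its label gets repeated
    print(word, tokenizer.tokenize(word))

# Full encoding as used in predict(): [CLS] and [SEP] are added automatically,
# which is why the first and last logit rows are stripped afterwards.
encoded = tokenizer.encode_plus(' '.join(words))
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))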

The logits obtained by running these lines:

self.eval()

bert_tokenized = self.tokenizer.encode_plus(' '.join(raw_text[0]))

input_ids, token_type_ids, attention_masks = bert_tokenized['input_ids'], bert_tokenized['token_type_ids'], bert_tokenized['attention_mask']

with torch.no_grad():
    logits = self.forward(torch.LongTensor(input_ids).unsqueeze(0), torch.LongTensor(attention_masks).unsqueeze(0)) 

The resulting logits:

logits = tensor([[-3.2811,  1.1715,  1.2381,  0.5201,  1.0921],
        [-3.2813,  1.1715,  1.2382,  0.5201,  1.0922],
        [-3.2815,  1.1716,  1.2383,  0.5202,  1.0923],
        [-3.2814,  1.1716,  1.2383,  0.5202,  1.0922],
        [-3.2811,  1.1715,  1.2381,  0.5201,  1.0921]])
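
Just to quantify how close those rows are (a quick throwaway check, not part of the model code):

import torch

logits = torch.tensor([[-3.2811, 1.1715, 1.2381, 0.5201, 1.0921],
                       [-3.2813, 1.1715, 1.2382, 0.5201, 1.0922],
                       [-3.2815, 1.1716, 1.2383, 0.5202, 1.0923],
                       [-3.2814, 1.1716, 1.2383, 0.5202, 1.0922],
                       [-3.2811, 1.1715, 1.2381, 0.5201, 1.0921]])

# spread of each logit dimension across the token positions: around 1e-4
print(logits.max(dim=0).values - logits.min(dim=0).values)

# the argmax is class 2 for every position, so every token gets the same label
print(logits.argmax(dim=1))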

Any ideas what I'm doing wrong? Thanks in advance for any help.
