I am trying to use BertForTokenClassification from Transformers with the pretrained multilingual BERT model to do a sequence-labelling task. During training everything looks fine and the loss decreases every epoch, but as soon as I put the model into evaluation mode and feed it input, the logits for every token come out the same.
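To be concrete, this is roughly the check I run to confirm the symptom; `model`, `input_ids` and `attention_mask` here are placeholders for my own objects, so treat it as a sketch rather than a full reproduction:

import torch

model.eval()
with torch.no_grad():
    out = model(input_ids, attention_mask=attention_mask)
logits = out[0].squeeze(0)  # (seq_len, num_labels)
# every row is numerically almost the same vector
print(torch.allclose(logits, logits[0].expand_as(logits), atol=1e-3))  # True in my case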
My code looks like this:
class PretrainedBert(nn.Module):
    def __init__(self,
                 device,
                 pretrained_path='bert-base-multilingual-cased',
                 output_dim=5,
                 learning_rate=0.01,
                 eps=1e-8,
                 max_grad_norm=1.0,
                 model_name='bert_multilingual',
                 cased=True
                 ):
        super(PretrainedBert, self).__init__()
        self.device = device
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        self.max_grad_norm = max_grad_norm
        self.model_name = model_name
        self.cased = cased

        # Load pretrained bert
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_path, num_labels=output_dim, output_attentions=False, output_hidden_states=False)
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_path, do_lower_case=not cased)

        # Group parameters so that bias/LayerNorm weights get no weight decay
        param_optimizer = list(self.bert.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=eps)
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask, labels=labels)
        return output
    def fit(self, train_loader, dev_loader, epochs=10):
        """
        :param train_loader: torch.utils.data.DataLoader object
        :param dev_loader: torch.utils.data.DataLoader object
        :param epochs: int
        """
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * epochs)

        for epoch in range(epochs):
            # Iterate over training data
            epoch_time = time.time()
            self.train()
            epoch_loss = 0
            num_batches = 0

            for raw_text, x, y, idx in tqdm(train_loader):
                self.zero_grad()

                # Get gold labels and extend them to cover the wordpieces created by bert
                y = self._extend_labels(raw_text, y)
                y, _ = pad_packed_sequence(y, batch_first=True)
                batches_len, seq_length = y.size()
                y.to(self.device)

                # Run input through bert and get logits
                bert_tokenized = self.tokenizer.batch_encode_plus(
                    [' '.join(text) for text in raw_text], max_length=seq_length, pad_to_max_length=True)
                input_ids, token_type_ids, attention_masks = bert_tokenized[
                    'input_ids'], bert_tokenized['token_type_ids'], bert_tokenized['attention_mask']
                loss, logits = self.forward(torch.LongTensor(
                    input_ids), torch.LongTensor(attention_masks), labels=y)

                # Help prevent the "exploding gradients" problem.
                torch.nn.utils.clip_grad_norm_(parameters=self.bert.parameters(), max_norm=self.max_grad_norm)

                # update parameters and learning rate
                loss.backward()
                self.optimizer.step()
                scheduler.step()

                epoch_loss += loss.item()
                num_batches += 1

            epoch_time = time.time() - epoch_time

            print()
            print("Epoch {0} loss: {1:.3f}".format(epoch + 1, epoch_loss / num_batches))
            print("Dev")
            binary_f1, propor_f1 = self.evaluate(dev_loader)
    def predict(self, test_loader):
        """
        :param test_loader: torch.utils.data.DataLoader object with
                            batch_size=1
        """
        self.eval()
        predictions, golds, sents = [], [], []

        for raw_text, x, y, idx in tqdm(test_loader):
            # Run input through bert and get logits
            bert_tokenized = self.tokenizer.encode_plus(' '.join(raw_text[0]))
            input_ids, token_type_ids, attention_masks = bert_tokenized[
                'input_ids'], bert_tokenized['token_type_ids'], bert_tokenized['attention_mask']
            with torch.no_grad():
                logits = self.forward(torch.LongTensor(
                    input_ids).unsqueeze(0), torch.LongTensor(attention_masks).unsqueeze(0))

            # remove batch dim and [CLS]+[SEP] tokens from logits
            logits = logits[0].squeeze(0)[1:-1]

            # mean pool wordpiece rows of logits and argmax to get predictions
            preds = self._logits_to_preds(logits, y[0], raw_text[0])

            predictions.append(preds)
            golds.append(y[0])
            sents.append(raw_text[0])

        return predictions, golds, sents

    def evaluate(self, test_loader):
        """
        Returns the binary and proportional F1 scores of the model on the examples passed via test_loader.
        :param test_loader: torch.utils.data.DataLoader object with
                            batch_size=1
        """
        preds, golds, sents = self.predict(test_loader)
        flat_preds = [int(i) for l in preds for i in l]
        flat_golds = [int(i) for l in golds for i in l]

        analysis = get_analysis(sents, preds, golds)
        binary_f1 = binary_analysis(analysis)
        propor_f1 = proportional_analysis(flat_golds, flat_preds)
        return binary_f1, propor_f1

    def _extend_labels(self, text, labels):
        extended_labels = []
        for idx, (text_seq, label_seq) in enumerate(zip(text, labels)):
            extended_seq_labels = []
            for word, label in zip(text_seq, label_seq):
                n_subwords = len(self.tokenizer.tokenize(word))
                extended_seq_labels.extend([label.item()] * n_subwords)
            extended_labels.append(torch.LongTensor(extended_seq_labels))
        extended_labels = pack_sequence(extended_labels, enforce_sorted=False)
        return extended_labels

    def _logits_to_preds(self, logits, y, raw_text):
        preds = torch.zeros(len(y))
        head = 0
        for idx, word in enumerate(raw_text):
            n_subwords = len(self.tokenizer.tokenize(word))
            if n_subwords > 1:
                preds[idx] = torch.argmax(torch.mean(
                    logits[head:head + n_subwords, :], dim=0))
            else:
                preds[idx] = torch.argmax(logits[head, :])
            head = head + n_subwords
        return preds
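For context, this is roughly how I drive the class; the dataset/DataLoader construction is omitted, and `train_loader`, `dev_loader` and `test_loader` are assumed to yield `(raw_text, x, y, idx)` batches as the methods above expect:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PretrainedBert(device, output_dim=5)
model.to(device)

# train for a few epochs, evaluating on the dev set after each one
model.fit(train_loader, dev_loader, epochs=10)

# final evaluation; the loader uses batch_size=1 as predict() expects
binary_f1, propor_f1 = model.evaluate(test_loader)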
Example input:
raw_text = [['Donkey', 'Kong', 'Country', ':']]
y = [tensor([0, 0, 0, 0])]
The logits are obtained by running these lines:
self.eval()
bert_tokenized = self.tokenizer.encode_plus(' '.join(raw_text[0]))
input_ids, token_type_ids, attention_masks = bert_tokenized['input_ids'], bert_tokenized['token_type_ids'], bert_tokenized['attention_mask']
with torch.no_grad():
    logits = self.forward(torch.LongTensor(input_ids).unsqueeze(0), torch.LongTensor(attention_masks).unsqueeze(0))
The resulting logits:
logits = tensor([[-3.2811, 1.1715, 1.2381, 0.5201, 1.0921],
                 [-3.2813, 1.1715, 1.2382, 0.5201, 1.0922],
                 [-3.2815, 1.1716, 1.2383, 0.5202, 1.0923],
                 [-3.2814, 1.1716, 1.2383, 0.5202, 1.0922],
                 [-3.2811, 1.1715, 1.2381, 0.5201, 1.0921]])
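Taking the argmax of each row makes the problem obvious: every token gets the same label. A tiny self-contained check on the numbers above:

import torch

logits = torch.tensor([[-3.2811, 1.1715, 1.2381, 0.5201, 1.0921],
                       [-3.2813, 1.1715, 1.2382, 0.5201, 1.0922],
                       [-3.2815, 1.1716, 1.2383, 0.5202, 1.0923],
                       [-3.2814, 1.1716, 1.2383, 0.5202, 1.0922],
                       [-3.2811, 1.1715, 1.2381, 0.5201, 1.0921]])

print(logits.argmax(dim=1))  # tensor([2, 2, 2, 2, 2]) -- the same class for every token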
Any ideas what I am doing wrong? Thanks in advance for any help.