Как получить доступ к прогнозам модели классификации Pytorch? (БЕРТ) - PullRequest
0 голосов
/ 18 мая 2019

Я запускаю этот файл: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

Это код прогноза для одной входной партии:

  input_ids = input_ids.to(device)
  input_mask = input_mask.to(device)
  segment_ids = segment_ids.to(device)
  label_ids = label_ids.to(device)

  with torch.no_grad():
       logits = model(input_ids, segment_ids, input_mask, labels=None)

       loss_fct = CrossEntropyLoss()
       tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

       eval_loss += tmp_eval_loss.mean().item()
       nb_eval_steps += 1
       if len(preds) == 0:
           preds.append(logits.detach().cpu().numpy())
       else:
           preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

Задача представляет собой двоичную классификацию. Я хочу получить доступ к двоичному выходу.

Я пробовал это:

  curr_pred = logits.detach().cpu()

  if len(preds) == 0:
      preds.append(curr_pred.numpy())
  else:
      preds[0] = np.append(preds[0], curr_pred.numpy(), axis=0)

  probablities = curr_pred.softmax(1).numpy()[:, 1]

Но результаты кажутся странными. Поэтому я не уверен, что это правильный путь.

Моя гипотеза - я получаю выходные данные последнего слоя, поэтому после softmax это истинные вероятности (вектор dim 2 - вероятность для 1-го и вероятность для 2-го класса.)

1 Ответ

0 голосов
/ 25 мая 2019

После просмотра этой части кода run_classifier.py:

    # copied from the run_classifier.py code 
    eval_loss = eval_loss / nb_eval_steps
    preds = preds[0]
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, all_label_ids.numpy())

Вы просто скучаете:

    preds = preds[0]
    preds = np.argmax(preds, axis=1)

Тогда они просто используют преды для вычисления точности как:

    def simple_accuracy(preds, labels):
         return (preds == labels).mean()
...