Вам следует просто считать их недействительными, потому что вы пытаетесь предсказать правильный диапазон ответа из переменной text
. Все остальное должно быть недействительным. Это также способ, которым huggingface обрабатывает следующие прогнозы:
Мы могли бы гипотетически создать неверные прогнозы, например, предсказать, что речь идет о начале диапазона. Мы отбрасываем все недопустимые прогнозы.
Вы также должны заметить, что они используют более сложный метод , чтобы получить прогнозы для каждого вопроса (не спрашивайте меня, почему они показывают факел .argmax в их примере). Пожалуйста, посмотрите на пример ниже:
from transformers.data.processors.squad import SquadResult, SquadExample, SquadFeatures,SquadV2Processor, squad_convert_examples_to_features
from transformers.data.metrics.squad_metrics import compute_predictions_logits, squad_evaluate
###
#your example code
###
outputs = model(**input_dict)
def to_list(tensor):
return tensor.detach().cpu().tolist()
output = [to_list(output[0]) for output in outputs]
start_logits, end_logits = output
all_results = []
all_results.append(SquadResult(1000000000, start_logits, end_logits))
#this is the answers section from the evaluation dataset
answers = [{'text':'not restored by the communist authorities', 'answer_start':77}, {'text':'were not restored', 'answer_start':72}, {'text':'not restored by the communist authorities after the war', 'answer_start':77}]
examples = [SquadExample('0', question, text, 'not restored by the communist authorities', 75, 'Warsaw', answers,False)]
#this does basically the same as tokenizer.encode_plus() but stores them in a SquadFeatures Object and splits if neccessary
features = squad_convert_examples_to_features(examples, tokenizer, 512, 100, 64, True)
predictions = compute_predictions_logits(
examples,
features,
all_results,
20,
30,
True,
'pred.file',
'nbest_file',
'null_log_odds_file',
False,
True,
0.0,
tokenizer
)
result = squad_evaluate(examples, predictions)
print(predictions)
for x in result.items():
print(x)
Вывод:
OrderedDict([('0', 'communist authorities after the war')])
('exact', 0.0)
('f1', 72.72727272727273)
('total', 1)
('HasAns_exact', 0.0)
('HasAns_f1', 72.72727272727273)
('HasAns_total', 1)
('best_exact', 0.0)
('best_exact_thresh', 0.0)
('best_f1', 72.72727272727273)
('best_f1_thresh', 0.0)