У меня есть функция с категорией прогнозирования на основе данных поезда. Наряду с этим я хочу вычислить оценку предсказанной категории, т. Е. Почему она относится к определенной категории, такой как политика, новости мира или спорт и т. Д.
Я копирую функцию для идеи
import spacy
nlp = spacy.load('en')
def predict_category(model, head, desc):
model.eval()
head = head.lower()
desc = desc.lower()
tokenized_head = [tok.text for tok in nlp.tokenizer(head)]
tokenized_desc = [tok.text for tok in nlp.tokenizer(desc)]
indexed_head = [TEXT.vocab.stoi[t] for t in tokenized_head]
indexed_desc = [TEXT.vocab.stoi[t] for t in tokenized_desc]
tensor_head = torch.LongTensor(indexed_head).to(device)
tensor_desc = torch.LongTensor(indexed_desc).to(device)
tensor_head = tensor_head.unsqueeze(1)
tensor_desc = tensor_desc.unsqueeze(1)
prediction = model(tensor_head, tensor_desc)
max_pred = prediction.argmax(dim=1)
return max_pred.item()
pred = predict_category(model, "Volkswagen Finance picks up 25 per cent stake in Kuwy Technology", "The partners will also offer finance, insurance and warranty products for Volkswagen group customers on Kuwy platform.")
print(f'Predicted category is: {pred} = {LABEL.vocab.itos[pred]}')