Таможенная потеря перекрестной энтропии в pytorch - PullRequest
3 голосов
/ 19 июня 2019

Я сделал пользовательскую реализацию функции кросс-энтропийной потери Pytorch (поскольку мне нужно больше гибкости, чтобы представить ее позже).Модель, которую я собираюсь обучить, будет нуждаться в значительном времени для обучения, и имеющиеся ресурсы не могут быть использованы для простого тестирования правильности реализации функции.Я реализовал векторизованную реализацию, поскольку она будет выполняться быстрее.

Ниже приведен мой код для этого:

def custom_cross(my_pred,true,batch_size=BATCH_SIZE):
    loss= -torch.mean(torch.sum(true.view(batch_size, -1) * torch.log(my_pred.view(batch_size, -1)), dim=1))
    return loss

Буду очень признателен, если вы сможете предложить более оптимизированную реализациюто же самое или если я делаю ошибку в настоящем.Модель будет использовать Nvidia Tesla K-80 для тренировки.

1 Ответ

2 голосов
/ 20 июня 2019

Если вам нужна только кросс-энтропия , вы можете воспользоваться тем преимуществом, который определил PyTorch.

import torch.nn.functional as F
loss_func = F.cross_entropy

предложить более оптимизированную реализацию

PyTorch имеет F. функций потерь, но вы можете легко написать свою собственную, используя простой Python. PyTorch автоматически создаст быстрый графический процессор или векторизованный код процессора для вашей функции.

Итак, вы можете проверить оригинальную реализацию PyTorch, но я думаю, что это:

def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

И здесь - исходная реализация кросс-энтропийной потери, теперь вы можете просто изменить:

nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

К чему-то, что вам нужно, и у вас это есть.

...