Пользовательская функция потери расстояния в Pytorch? - PullRequest
0 голосов
/ 29 октября 2019

Я хочу реализовать следующую функцию потери расстояния в pytorch. Я следил за этой https://discuss.pytorch.org/t/custom-loss-functions/29387/4 веткой форума pytorch

np.linalg.norm(output - target)
# where output.shape = [1, 2] and target.shape = [1, 2]

Итак, я реализовал функцию потерь, такую ​​как

def my_loss(output, target):    
    loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))
    return loss

, с помощью этой функции потерь, вызывая обратный вызовошибка времени выполнения

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Весь мой код выглядит следующим образом

model = nn.Linear(2, 2)

x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)

loss = my_loss(output, target)
loss.backward()   <----- Error here

print(model.weight.grad)

PS: я знаю о попарной потере pytorch, но из-за некоторых ее ограничений я должен ее реализоватьсебя.

Следуя исходному коду Pytorch, я попробовал следующее:

class my_function(torch.nn.Module): # forgot to define backward()
    def forward(self, output, target):

        loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))
        return loss

model = nn.Linear(2, 2)
x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)

criterion = my_function()

loss = criterion(output, target)


loss.backward()
print(model.weight.grad)

И я получаю ошибку времени выполнения

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Как правильно реализовать функцию потерь

1 Ответ

5 голосов
/ 29 октября 2019

Это происходит потому, что в функции потерь вы отделяете тензоры. Вы должны были отсоединиться, потому что хотели использовать np.linalg.norm. Это нарушает график, и вы получаете ошибку, что у тензоров нет града fn.

Вы можете заменить

loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))

на операции резака как

loss = torch.norm(output-target)

Это должно работать нормально.

...