Я реализую RL в PyTorch, и мне пришлось написать собственную функцию mse_loss (которую я нашел в Stackoverflow;)).Функция потерь:
def mse_loss(input_, target_):
return torch.sum(
(input_ - target_) * (input_ - target_)) / input_.data.nelement()
Теперь, в моем цикле тренировки, первый вход выглядит примерно так:
tensor([-1.7610e+10]), tensor([-6.5097e+10])
С этим вводом я получу ошибку:
Unable to get repr for <class 'torch.Tensor'>
Вычисление a = (input_ - target_)
работает нормально, в то время как b = a * a
соответственно b = torch.pow(a, 2)
завершится с ошибкой, упомянутой выше.
Кто-нибудь знает, как исправить это?
Большое спасибо!
Обновление : Я только что попытался использовать torch.nn.functional.mse_loss
, что приведет к той же ошибке ..