Построить производную функции с помощью PyTorch? - PullRequest
0 голосов
/ 29 ноября 2018

У меня есть этот код:

import torch
import matplotlib.pyplot as plt  
x=torch.linspace(-10, 10, 10, requires_grad=True)
y = torch.sum(x**2)
y.backward()
plt.plot(x.detach().numpy(), y.detach().numpy(), label='function')
plt.legend()

Но я получил эту ошибку:

ValueError: x and y must have same first dimension, but have shapes (10,) and (1,)

Ответы [ 2 ]

0 голосов
/ 30 ноября 2018

Я думаю, что главная проблема в том, что ваши размеры не совпадают.Почему вы не хотите использовать torch.sum?

Это должно работать для вас:

# %matplotlib inline added this line only for jupiter notebook
import torch
import matplotlib.pyplot as plt  
x = torch.linspace(-10, 10, 10, requires_grad=True)

y = x**2      # removed the sum to stay with the same dimensions
y.backward(x) # handing over the parameter x, as y isn't a scalar anymore
# your function
plt.plot(x.detach().numpy(), x.detach().numpy(), label='x**2')
# gradients
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='grad')
plt.legend()

График вывода:

enter image description here

Вы получите более красивую картинку, хотя с большим количеством шагов я также немного изменил интервал на torch.linspace(-2.5, 2.5, 50, requires_grad=True):

enter image description here

Редактировать относительно комментария:

Эта версия отображает градиенты с torch.sum в комплекте:

# %matplotlib inline added this line only for jupiter notebook
import torch
import matplotlib.pyplot as plt  
x = torch.linspace(-10, 10, 10, requires_grad=True)

y = torch.sum(x**2) 
y.backward() 
print(x.grad)
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='grad')
plt.legend()

Вывод:

tensor([-20.0000, -15.5556, -11.1111,  -6.6667,  -2.2222,   2.2222,
      6.6667,  11.1111,  15.5556,  20.0000])

Сюжет:

enter image description here

0 голосов
/ 30 ноября 2018

Я предполагаю, что вы хотите построить график производной от x**2.

Затем вам нужно построить график между x и x.grad НЕ x и y т.е.

plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='function').

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...