Для меня это похоже на случай программирования груза.
Обратите внимание, что ваш класс Model
не использует self
в forward
, поэтому он фактически является "обычным" (не метод), и model
полностью без сохранения состояния.Самое простое исправление в вашем коде - сделать optimizer
осведомленным о w
и b
, создав его как optimizer = torch.optim.SGD([w, b], lr=0.01)
.Я также переписываю model
как функцию
import torch
import torch.nn as nn
# torch.autograd.Variable is roughly equivalent to requires_grad=True
# and is deprecated in PyTorch 1.0
# your code gives not reason to have `requires_grad=True` on `x_data`
x_data = torch.tensor( [ [1.0, 2.0], [2.0, 3.0], [3.0, 4.0] ])
y_data = torch.tensor( [ [2.0], [4.0], [6.0] ] )
w = torch.randn( 2, 1, requires_grad=True )
b = torch.randn( 1, 1, requires_grad=True )
def model(x2, w2, b2):
return x2 @ w2 + b2
criterion = torch.nn.MSELoss( size_average=False )
optimizer = torch.optim.SGD([w, b], lr=0.01 )
for epoch in range(10) :
y_pred = model( x_data,w,b )
loss = criterion( y_pred, y_data )
print( epoch, loss.data.item() )
optimizer.zero_grad()
loss.backward()
optimizer.step()
При этом nn.Linear
создан для упрощения этой процедуры.Он автоматически создает эквиваленты w
и b
, называемые self.weight
и self.bias
соответственно.Кроме того, self.__call__(x)
эквивалентно определению форварда вашего Model
в том смысле, что он возвращает self.weight @ x + self.bias
.Другими словами, вы также можете использовать альтернативный код
import torch
import torch.nn as nn
x_data = torch.tensor( [ [1.0, 2.0], [2.0, 3.0], [3.0, 4.0] ] )
y_data = torch.tensor( [ [2.0], [4.0], [6.0] ] )
model = nn.Linear(2, 1)
criterion = torch.nn.MSELoss( size_average=False )
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 )
for epoch in range(10) :
y_pred = model(x_data)
loss = criterion( y_pred, y_data )
print( epoch, loss.data.item() )
optimizer.zero_grad()
loss.backward()
optimizer.step()
, где model.parameters()
можно использовать для перечисления параметров модели (эквивалентно списку, созданному вручную [w, b]
выше).Для доступа к вашим параметрам (загрузка, сохранение, печать и т. Д.) Используйте model.weight
и model.bias
.