Обратите внимание, что в отличие от других вопросов, речь идет не о какой-либо структуре RNN. Я хочу создать модель, которая имеет изменяющиеся градиенты и будет выглядеть ниже. Точки останова поставляются вручную.
Модель, которую я создал, выглядит следующим образом:
class Trend(nn.Module):
"""
Broken Trend model, with breakpoints as defined by user.
"""
def __init__(self, breakpoints):
super().__init__()
self.bpoints = breakpoints[None, :]
self.init_layer = nn.Linear(1,1) # first linear bit
# extract gradient and bias
w = self.init_layer.weight
b = self.init_layer.bias
self.params = [[w,b]] # save it to buffer
if len(breakpoints>0):
# create deltas which is how the gradient will change
deltas = torch.randn(len(breakpoints)) / len(breakpoints) # initialisation
self.deltas = nn.Parameter(deltas) # make it a parameter
for d, x1 in zip(self.deltas, breakpoints):
y1 = w *x1 + b # find the endpoint of line segment (x1, y1)
w = w + d # add on the delta to gradient
b = y1 - w * x1 # find new bias of line segment
self.params.append([w,b]) # add to buffer
# create buffer
self.wb = torch.zeros(len(self.params), len(self.params[0]))
def __copy2array(self):
"""
Saves parameters into wb
"""
for i in range(self.wb.shape[0]):
for j in range(self.wb.shape[1]):
self.wb[i,j] = self.params[i][j]
def forward(self, x):
# get the line segment area (x_sec) for each x
x_sec = x >= self.bpoints
x_sec = x_sec.sum(1)
self.__copy2array() # copy across parameters into matrix
# get final prediction y = mx +b for relevant section
return x*self.wb[x_sec][:,:1] + self.wb[x_sec][:,1:]
Однако, когда я пытаюсь обучить ее, я получаю ошибку RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time
.
Я получил вышеуказанный график, выполнив:
time = torch.arange(700).float()[:,None]
y_pred = model(time)
plt.plot(time, y_pred.detach().numpy())
plt.show()
Итак, мы знаем, что прямой проход работает, как и ожидалось. Однако обратный проход не совсем работает. Мне было интересно, что мне нужно изменить, чтобы заставить его работать.
Если вам интересно, почему используется __copy2array
, когда я попытался использовать torch.Tensor(self.params)
, он уничтожил градиенты в этих параметрах. Заранее спасибо.