Non LSTM: Попытка перебрать график во второй раз, но буферы уже освобождены - PullRequest
0 голосов
/ 25 октября 2019

Обратите внимание, что в отличие от других вопросов, речь идет не о какой-либо структуре RNN. Я хочу создать модель, которая имеет изменяющиеся градиенты и будет выглядеть ниже. Точки останова поставляются вручную. enter image description here

Модель, которую я создал, выглядит следующим образом:

class Trend(nn.Module):
    """
    Broken Trend model, with breakpoints as defined by user.
    """
    def __init__(self, breakpoints):
        super().__init__()
        self.bpoints = breakpoints[None, :]
        self.init_layer = nn.Linear(1,1) # first linear bit
        # extract gradient and bias
        w = self.init_layer.weight
        b = self.init_layer.bias
        self.params = [[w,b]] # save it to buffer

        if len(breakpoints>0):
            # create deltas which is how the gradient will change
            deltas = torch.randn(len(breakpoints)) / len(breakpoints) # initialisation
            self.deltas = nn.Parameter(deltas) # make it a parameter

            for d, x1 in zip(self.deltas, breakpoints):
                y1 = w *x1 + b # find the endpoint of line segment (x1, y1)
                w = w + d # add on the delta to gradient 
                b = y1 - w * x1 # find new bias of line segment 
                self.params.append([w,b]) # add to buffer

        # create buffer
        self.wb = torch.zeros(len(self.params), len(self.params[0]))

    def __copy2array(self):
        """
        Saves parameters into wb
        """
        for i in range(self.wb.shape[0]):
            for j in range(self.wb.shape[1]):
                self.wb[i,j] = self.params[i][j]

    def forward(self, x):
        # get the line segment area (x_sec) for each x
        x_sec = x >= self.bpoints
        x_sec = x_sec.sum(1)
        self.__copy2array() # copy across parameters into matrix

        # get final prediction y = mx +b for relevant section
        return x*self.wb[x_sec][:,:1] + self.wb[x_sec][:,1:]

Однако, когда я пытаюсь обучить ее, я получаю ошибку RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Я получил вышеуказанный график, выполнив:

time = torch.arange(700).float()[:,None]
y_pred = model(time)
plt.plot(time, y_pred.detach().numpy())
plt.show()

Итак, мы знаем, что прямой проход работает, как и ожидалось. Однако обратный проход не совсем работает. Мне было интересно, что мне нужно изменить, чтобы заставить его работать.

Если вам интересно, почему используется __copy2array, когда я попытался использовать torch.Tensor(self.params), он уничтожил градиенты в этих параметрах. Заранее спасибо.

1 Ответ

1 голос
/ 25 октября 2019

Поскольку ваш ответ не содержит полного кода, судить сложно, но я рекомендую попробовать то, что написано в сообщении об ошибке: замените .backward() на .backward( retain_graph=True). Это означает, что градиент не удаляется после обновления.

...