Pytorch. Оптимизация входных данных для модели: повторная попытка перевернуть график, но буферы уже освобождены - PullRequest
0 голосов
/ 07 ноября 2019

В настоящее время я пытаюсь оптимизировать значения входного тензора х для модели. Я хочу ограничить ввод только значениями в диапазоне [0,0; 1,0].

Существует не слишком много информации о том, как это сделать, если не работать со слоем как таковым.

Я создал минимальный рабочий пример ниже, который дает сообщение об ошибке в заголовке этого сообщения.

Волшебство происходит в функции optimize_x ()

Если я закомментирую строку: model.x = model.x.clamp(min=0.0, max=1.0) проблема исправлена, но тензор явно не зажат.

Я знаю, что мог бы просто установить retain_graph=True - но не ясно, является ли этоправильный путь, или если есть лучший способ достижения этой функциональности?

import torch
from torch.distributions import Uniform


class OptimizeInputModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
                torch.nn.Linear(123, 1000),
                torch.nn.Dropout(0.4),
                torch.nn.ReLU(),
                torch.nn.Linear(1000, 100),
                torch.nn.Dropout(0.4),
                torch.nn.ReLU(),
                torch.nn.Linear(100, 1),
                torch.nn.Sigmoid(),
        )

        in_shape = (1, 123)
        self.x = torch.ones(in_shape) * 0.1
        self.x.requires_grad = True

    def forward(self) -> torch.Tensor:
        return self.model(self.x)

class MyLossFunc(torch.nn.Module):

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        loss = torch.sum(-y)
        return loss


def optimize_x():
    model = OptimizeInputModel()
    optimizer = torch.optim.Adam([model.x], lr=1e-4)
    loss_fn = MyLossFunc()
    for epoch in range(50000):
        # Constrain X to have no values < 0
        model.x = model.x.clamp(min=0.0, max=1.0)
        y = model()
        loss = loss_fn(y)

        if epoch % 9 == 0:
            print(f'Epoch: {epoch}\t Loss: {loss}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


optimize_x()

Полное сообщение об ошибке: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

1 Ответ

0 голосов
/ 13 ноября 2019

Для тех, кто в будущем может задать тот же вопрос.

Мое решение было сделать (обратите внимание на подчеркивание!):

model.x.data.clamp_(min=0.0, max=1.0)

вместо:

model.x = model.x.clamp(min=0.0, max=1.0)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...