Как очистить память Cuda в PyTorch - PullRequest
1 голос
/ 24 марта 2019

Я пытаюсь получить вывод нейронной сети, которую я уже обучил.На входе изображение размером 300х300.Я использую размер пакета 1, но я все еще получаю ошибку CUDA error: out of memory после того, как я успешно получил вывод для 25 изображений.

Я искал некоторые решения онлайн и наткнулся на torch.cuda.empty_cache().Но это все еще не решает проблему.

Это код, который я использую.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_x = torch.tensor(train_x, dtype=torch.float32).view(-1, 1, 300, 300)
train_x = train_x.to(device)
dataloader = torch.utils.data.DataLoader(train_x, batch_size=1, shuffle=False)

right = []
for i, left in enumerate(dataloader):
    print(i)
    temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))
    del temp
    torch.cuda.empty_cache()

Этот for loop запускается 25 раз каждый раз, прежде чем выдавать ошибку памяти.

Каждый раз я отправляю в сеть новое изображение для вычислений.Поэтому мне не нужно сохранять результаты предыдущих вычислений в графическом процессоре после каждой итерации цикла.Есть ли способ достичь этого?

Любая помощь будет оценена.Спасибо.

1 Ответ

1 голос
/ 25 марта 2019

Я понял, где я иду не так.Я публикую решение как ответ для тех, кто может столкнуться с той же проблемой.

По сути, PyTorch делает то, что он создает вычислительный граф всякий раз, когда я передаю данные через свою сеть, и сохраняет вычисления впамять графического процессора, на случай, если я хочу рассчитать градиент во время обратного распространения.Но так как я хотел только выполнить прямое распространение, мне просто нужно было указать torch.no_grad() для моей модели.

Таким образом, цикл for в моем коде мог быть переписан как:

for i, left in enumerate(dataloader):
    print(i)
    with torch.no_grad():
        temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))
    del temp
    torch.cuda.empty_cache()

Указание no_grad() для моей модели сообщает PyTorch, что я не хочу сохранять какие-либо предыдущие вычисления, освобождая тем самым пространство моего графического процессора.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...