PyTorch: при использовании backward (), как я могу сохранить только часть графика? - PullRequest
0 голосов
/ 07 июня 2018

У меня есть вычислительный граф PyTorch, который состоит из подграфа, выполняющего некоторое вычисление, и результат этого вычисления (назовем его x) затем разветвляется на два других подграфа.Каждый из этих двух подграфов дает некоторые скалярные результаты (назовем их y1 и y2).Я хочу сделать обратный проход для каждого из этих двух результатов (то есть я хочу накапливать градиенты двух подграфов. Я не хочу выполнять фактический шаг оптимизации).

Теперь,так как память здесь является проблемой, я хотел бы выполнить операции в следующем порядке: во-первых, рассчитать x.Затем вычислите y1 и выполните y1.backward(), в то время как (и это ключевой момент) , сохраняя график, ведущий к x, но освобождая график от x до y1.Затем вычислите y2 и выполните y2.backward().

Другими словами, чтобы сэкономить память, не жертвуя слишком большой скоростью, я хочу сохранить x без необходимости пересчитывать ее, но я хочу отброситьвсе вычисления, начинающиеся с x до y1 после того, как они мне больше не нужны.

Проблема в том, что аргумент retain_graph функции backward() сохранит весь граф, ведущий к y1, тогда как мне нужно сохранить только часть графика, ведущую к x.

Вот пример того, что я в идеале хотел бы:

import torch

w = torch.tensor(1.0)
w.requires_grad_(True)

# sub-graph for calculating `x`
x = w+10

# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.

# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()

Как это можно сделать

1 Ответ

0 голосов
/ 07 июня 2018

Аргумент retain_graph сохранит весь граф, а не только подграф.Однако мы можем использовать сборку мусора, чтобы освободить ненужные части графика.При удалении всех ссылок на подграф с x до y1 этот подграф будет освобожден:

import torch

w = torch.tensor(1.0)
w.requires_grad_(True)

# sub-graph for calculating `x`
x = w+10

# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=True) # all graph is retained

# remove unneeded parts of graph. Note that these parts will be freed from memory (even if they were on GPU), due to python's garbage collection 
y1 = None
x1 = None

# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...