Я написал функцию для этого. Два ключевых компонента: (1) использование retain_graph=True
для всех, кроме последнего вызова .backward()
и (2) сохранение оценок после каждого вызова .backward()
и восстановление их в конце перед .step()
ing.
def multi_step(losses, optms):
# optimizers each take a step, with `optms[i]`'s variables being
# optimized w.r.t. `losses[i]`.
grads = [None]*len(losses)
for i, (loss, optm) in enumerate(zip(losses, optms)):
retain_graph = i != (len(losses)-1)
optm.zero_grad()
loss.backward(retain_graph=retain_graph)
grads[i] = [
[
p.grad+0 for p in group['params']
] for group in optm.param_groups
]
for optm, grad in zip(optms, grads):
for p_group, g_group in zip(optm.param_groups, grad):
for p, g in zip(p_group['params'], g_group):
p.grad = g
optm.step()
В примере кода, указанного в вопросе, multi_step
будет использоваться следующим образом:
for i in range(n_iter):
shared_computation = foobar(x, y, z)
x_loss = f(x, y, z, shared_computation)
y_loss = g(x, y, z, shared_computation)
z_loss = h(x, y, z, shared_computation)
multi_step([x_loss, y_loss, z_loss], [x_opt, y_opt, z_opt])