Как мне изменить эту функцию new_cdist () , чтобы устранить Ошибка времени выполнения GPU вне памяти ?
В частности, замена torch.cdist()
чем-то более легким в памяти.
Примечание. Следующий фрагмент кода относится к новому типу уравнения обратного распространения и адаптивной скорости обучения .
См. здесь для получения дополнительной информации о том, как работает существующий код
def new_cdist(p, eta):
class cdist(torch.autograd.Function):
@staticmethod
def forward(ctx, W, X):
ctx.save_for_backward(W, X)
out = -torch.cdist(W, X, p)
return out
@staticmethod
def backward(ctx, grad_output):
W, X = ctx.saved_tensors
grad_W = grad_X = None
if ctx.needs_input_grad[0]:
_temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
_temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
_temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
grad_W = torch.matmul(grad_output, _temp)
# print('before norm: ', torch.norm(grad_W))
grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
print('after norm: ', torch.norm(grad_W))
if ctx.needs_input_grad[1]:
_temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
_temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
_temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
_temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
return grad_W, grad_X
return cdist().apply