как насчет попробовать это.просто я просто переписал функцию pyTorch в стиле Chainer.
import cupy
def clip_grad_norm(model, max_norm, norm_type=2):
params = list( filter(lambda p : p.grad is not None , model.params()) )
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
for p in params:
g = p.grad
norm = cupy.linalg.norm(g)
total_norm += norm**(norm_type)
total_norm = total_norm **(1/norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in params:
g = p.grad
p.grad = g * clip_coef