Есть ли лучший / более быстрый способ реализовать следующий фрагмент кода в Pytorch, избегая при этом цикла и сохраняя вычислительный граф без изменений?
def cumulative_max(X, dim=-1):
out = X.clone()
if dim < 0:
dim += X.dim()
leading_indices = (slice(None), ) * dim
n_iters = X.size(dim)
for idx in range(1, n_iters):
out[leading_indices + (idx, )] = torch.max(out[leading_indices + (idx - 1, )], X[leading_indices + (idx, )])
return out