При выполнении холеского разложения PyTorch использует LAPACK для тензоров ЦП и MAGMA для тензоров CUDA. В коде PyTorch, используемом для вызова LAPACK , пакет просто повторяется, вызывая функцию LAPACK zpotrs_
для каждой матрицы отдельно. В коде PyTorch, используемом для вызова MAGMA , весь пакет обрабатывается с использованием magma_dpotrs_batched
MAGMA, что, вероятно, быстрее, чем итерация по каждой матрице в отдельности.
AFAIK нет способ указать MAGMA или LAPACK не вызывать исключения (хотя, если честно, я не эксперт по этим пакетам). Так как MAGMA может каким-то образом использовать пакеты, мы можем не захотеть просто использовать по умолчанию итеративный подход, поскольку мы потенциально теряем производительность, не выполняя пакетный режим cholesky.
Одно из возможных решений - сначала попробовать и выполнить пакетный режим. декомпозиция cholesky, если она завершится неудачно, мы могли бы выполнить декомпозицию cholesky для каждого элемента в пакете, установив записи, для которых не было NaN.
def cholesky_no_except(x, upper=False, force_iterative=False):
success = False
if not force_iterative:
try:
results = torch.cholesky(x, upper=upper)
success = True
except RuntimeError:
pass
if not success:
# fall back to operating on each element separately
results_list = []
x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
for batch_idx in range(x_batched.shape[0]):
try:
result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
except RuntimeError:
# may want to only accept certain RuntimeErrors add a check here if that's the case
# on failure create a "nan" matrix
result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
results_list.append(result)
results = torch.cat(results_list, dim=0).reshape(*x.shape)
return results
Если вы ожидаете, что исключения будут распространены во время декомпозиции cholesky, вы можете использовать force_iterative=True
, чтобы пропустить начальный вызов, который пытается использовать пакетную версию, поскольку в этом случае эта функция, скорее всего, будет просто тратить время с первой попытки.