Память и время в тензорных операциях python - PullRequest
0 голосов
/ 29 января 2019

Цель Моя цель - вычислить тензор, задаваемый формулой, которую вы можете увидеть ниже.Индексы i, j, k, l варьируются от 0 до 40, а p, m, x от 0 до 80.

The formula for the tensor

Тенсордотподход Это суммирование только сжимает 6 индексов огромного тензора.Я попытался сделать это с помощью тензорной точки, которая учитывает такие вычисления, но тогда моя проблема - память, даже если я делаю одну тензорную точку, а затем другую.(Я работаю в Colab, поэтому у меня доступно 12 ГБ ОЗУ)

Подход с использованием вложенных циклов Но есть некоторые дополнительные симметрии, управляющие матрицей B, т.е. единственные ненулевые элементы в B {ijpx} таковычто я + J = P + X.Поэтому я смог написать p и m как функцию от x (p = i + jx, m = k + lx), а затем я сделал 5 циклов именно для i, j, k, l, x, но затем с другой стороныпроблема заключается во времени, так как расчет занимает 136 секунд, и я хочу повторить это много раз.

Цели синхронизации в подходе с вложенными циклами Сокращение времени в десять раз было бы удовлетворительным, но если было бы возможно уменьшить его в 100 раз, этого было бы более чем достаточно.

Есть ли у вас какие-либо идеи по поводу проблемы с памятью или сокращения времени?Как вы обрабатываете такие суммирования с дополнительными ограничениями?

(Примечание. Матрица A является симметричной, и я до сих пор не использовал этот факт. Симметрий больше нет.)

Вот код для вложенного цикла:

for i in range (0,40):
  for j in range (0,40):
    for k in range (0,40):
      for l in range (0,40):
            Sum=0
            for x in range (0,80):
              p=i+j-x
              m=k+l-x
              if p>=0 and p<80 and m>=0 and m<80:
                Sum += A[p,m]*B[i,j,p,x]*B[k,l,m,x]
            T[i,j,k,l]= Sum

И код для подхода тензорной точки:

P=np.tensordot(A,B,axes=((0),(2)))
T=np.tensordot(P,B,axes=((0,3),(2,3)))

1 Ответ

0 голосов
/ 29 января 2019

Numba может быть вашим лучшим выбором здесь.Я собрал эту функцию на основе вашего кода.Я немного его изменил, чтобы избежать ненужных итераций и блока if:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def my_formula_nb(A, B):
    di, dj, dx, _ = B.shape
    T = np.zeros((di, dj, di, dj), dtype=A.dtype)
    for i in nb.prange (di):
        for j in nb.prange (dj):
            for k in nb.prange (di):
                for l in nb.prange (dj):
                    sum = 0
                    x_start = max(0, i + j - dx + 1, k + l - dx + 1)
                    x_end = min(dx, i + j + 1, k + l + 1)
                    for x in range(x_start, x_end):
                        p = i + j - x
                        m = k + l - x
                        sum += A[p, m] * B[i, j, p, x] * B[k, l, m, x]
                    T[i, j, k, l] = sum
    return T

Давайте посмотрим на это в действии:

import numpy as np

def make_problem(di, dj, dx):
    a = np.random.rand(dx, dx)
    a = a + a.T
    b = np.random.rand(di, dj, dx, dx)
    b_ind = np.indices(b.shape)
    b_mask = b_ind[0] + b_ind[1] != b_ind[2] + b_ind[3]
    b[b_mask] = 0
    return a, b

# Generate a problem
np.random.seed(100)
a, b = make_problem(15, 20, 25)
# Solve with Numba function
t1 = my_formula_nb(a, b)
# Solve with einsum
t2 = np.einsum('pm,ijpx,klmx->ijkl', a, b, b)
# Check result
print(np.allclose(t1, t2))
# True

# Benchmark (IPython)
%timeit np.einsum('pm,ijpx,klmx->ijkl', a, b, b)
# 4.5 s ± 39.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit my_formula_nb(a, b)
# 6.06 ms ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Как видите, решение Numbaпримерно на три порядка быстрее, и это не должно занимать больше памяти, чем необходимо.

...