Большинство решений сложных проблем включает разбиение на более мелкие проблемы и их решение по отдельности.
Мы можем переписать эту задачу в 2 части:
Генерация последовательностей операций.
Например, для F(A, B, 2, 1) -> AAB+ABA+BAA
(неявное умножение матриц)
Вычисление (эффективно) этих операций. Мы можем заметить, что некоторые вычисления будут выполнены несколько раз. Например, когда мы вычисляем AAB+ABA+BAA
, мы можем сгруппировать все умножения AA
и сохранить результат для последующего использования при необходимости.
Для генерации последовательностей мы можем использовать more_itertools
distinct_permutations
функция . Кодируя A
в 0
и B
в 1
, он генерирует последовательности вычислений, которые мы хотим.
Чтобы выполнить вычисления для одной последовательности, мы должны воспользоваться преимуществами предыдущих вычислений. Мы можем использовать запоминание, чтобы запомнить предыдущие результаты и выполнить их только один раз.
# /!\ IMPORTANT: This initialization is wrong in some case, see the EDIT.
memo = {(0,): A,
(1,): B
} # should be initialized everytimes A, or B changes.
def matmul_perm(A, B, perm):
if perm in memo: # If previously computed, return result
return memo[perm]
mid = len(perm) // 2
memo[perm] = (matmul_perm(A, B, perm[:mid]) @
matmul_perm(A, B, perm[mid:])) # Split computation in 2 equal part and store result
return memo[perm]
Теперь мы можем определить нашу функцию:
def F(A, B, na, nb):
s = 0
for perm in distinct_permutations((0,)*na + (1,)*nb):
s += matmul_perm(A, B, perm)
return s
И, наконец, протестировать нашу программу:
A = np.random.randn(50, 50)
B = np.random.randn(50, 50)
memo = {(1,): B, (0,): A}
np.max(np.abs(F(A, B, 2, 2) - (A@A@B@B + A@B@A@B + B@A@A@B + A@B@B@A + B@A@B@A + B@B@A@A)))
>>> 1.8189894035458565e-12
На моем компьютере требуется около 6 сек c для вычисления F(A, B, 10, 10)
с A
и B
размером 50×50
и запиской fre sh. (и менее чем за секунду, чтобы пересчитать его во второй раз)
EDIT
Эта реализация повторяется вечно при вызове с na=nb=0
. Самое простое решение - изменить инициализацию memo
на
memo = {(0,): A,
(1,): B,
(): np.eye(*A.shape) # Empty product !
}