Первоначально я разместил вопрос о том, как эффективно вычислить logsumexp, найденный здесь
Как эффективно вычислить logsumexp верхнего треугольника во вложенном цикле?
Ответ, который я принял, был
import Numpy as np
Wm = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]])
wx = np.array([1, 2, 3])
wy = np.array([4, 5, 6])
Wxy = np.array([[5, 6, 7],
[6, 7, 8],
[7, 8, 9]])
'''
np.triu_indices = ([0, 0, 1], [1, 2, 2])
Wxy[triu_inds] = [6, 7, 8]
np.logsumexp(Wxy[triu_inds]) = log(exp(6) + exp(7) + exp(8))
'''
for x in range(n-1):
wx = Wm[x, :]
for y in range(x+1, n):
wy = Wm[y, :]
Wxy = np.add.outer(wx, wy)
Wxy = Wxy[triu_inds]
W[x, y] = np.logsumexp(Wxy)
# solution here
W = np.logsumexp(
np.add.outer(Wm, Wm).swapaxes(1, 2)[(slice(None),)*2 + triu_inds],
axis=-1 # Perform summation over last axis.
)
W = np.triu(W, k=1)
Проблема в том, что это действительно медленно с большими матрицами, поскольку проблема быстро взрывается. Если размер Wm
равен m,n
, то объем необходимой памяти увеличивается на (m*n)**2 * 8
байт. Мне нужно запустить это на матрицах размером больше 1000x200
, но я получаю ошибки памяти, и это очень медленно.