Как вычислить сумму квадратов внешних произведений двух матриц минус общая матрица в Matlab? - PullRequest
0 голосов
/ 12 ноября 2018

Предположим, есть три n * n матриц X, Y, S. Как быстро вычислить следующие скаляры b

for i = 1:n
  b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
end

Стоимость вычисления O (n ^ 3). Существует быстрый способ вычисления внешнего произведения двух матриц . В частности, матрица C

for i = 1:n
  C = C + X(i,:)' * Y(i,:);
end

можно вычислить без цикла for C = A.'*B, который равен только O (n ^ 2) Существует ли более быстрый способ вычисления b?

Ответы [ 2 ]

0 голосов
/ 12 ноября 2018

Вы можете использовать:

X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));

Учитывая ваш пример

b=0;
for i = 1:n
   b = b  + sum(sum((X(i,:).' * Y(i,:) - S).^2));
end

Сначала мы можем вывести суммирование из цикла:

b=0;
for i = 1:n
  b = b  + (X(i,:).' * Y(i,:) - S).^2;
end
b=sum(b(:))

Зная, что мы можем написать (a - b)^2 как a^2 - 2*a*b + b^2

b=0;
for i = 1:n
  b = b  + (X(i,:).' * Y(i,:)).^2 - 2.* (X(i,:).' * Y(i,:)) .*S + S.^2;
end
b=sum(b(:))

И мы знаем, что (a * b) ^ 2 совпадает с a^2 * b^2:

X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b=0;
for i = 1:n
  b = b  + (X2(i,:).' * Y2(i,:)) - 2.* (X(i,:).' * Y(i,:)) .*S + S2;
end
b=sum(b(:))

Теперь мы можем вычислить каждый член отдельно:

 b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));

Вот результат теста в Octave, который сравнивает мой метод и два других метода, предоставленных @AndrasDeak, и оригинальное решение на основе циклов для входных данных размером 500*500:

===rahnema1 (B)===
Elapsed time is 0.0984299 seconds.

===Andras Deak (B2)===
Elapsed time is 7.86407 seconds.

===Andras Deak (B3)===
Elapsed time is 2.99158 seconds.

===Loop solution===
Elapsed time is 2.20357 seconds


n=500;
X= rand(n);
Y= rand(n);
S= rand(n);

disp('===rahnema1 (B)===')
tic
    X2 = X.^2;
    Y2 = Y.^2;
    S2 = S.^2;
    b=sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
toc
disp('===Andras Deak (B2)===')
tic
    b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));
toc
disp('===Andras Deak (B3)===')
tic
    b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));
toc
tic
    b=0;
    for i = 1:n
      b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
    end
toc
0 голосов
/ 12 ноября 2018

Вы, вероятно, не можете сэкономить время, но вы можете использовать векторизацию, чтобы избавиться от цикла и максимально использовать низкоуровневый код и кэширование. То, будет ли это на самом деле быстрее, зависит от ваших размеров, поэтому вам нужно провести несколько временных тестов, чтобы понять, стоит ли это того:

% dummy data
n = 3;
X = rand(n);
Y = rand(n);
S = rand(n);

% vectorize
b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));

% check
b - b2 % close to machine epsilon i.e. zero

В результате мы вставляем новое одноэлементное измерение в один из массивов, заканчивая массивом размером [n, 1, n] против массива с [n, n], причем последнее неявно совпадает с [n, n, 1]. Первый перекрывающийся индекс соответствует i в вашем цикле, остальные два индекса соответствуют индексам матрицы диадического произведения, которое вы имеете для каждого i. Затем мы переставляем индексы, чтобы поставить индекс «i» последним, чтобы мы могли снова передать результат с S (неявного) размера [n, n, 1]. Затем мы имеем матрицу размером [n, n, n], где первые два индекса - это матричные индексы в вашем оригинале, а последний соответствует i. Затем нам нужно просто взять квадрат и сложить каждое слагаемое (вместо того, чтобы суммировать дважды, я преобразовал массив в строку и суммировал один раз).

Небольшое изменение вышеуказанных транспонирований S вместо 3d-массива, которое может быть быстрее (опять же, вы должны рассчитать время):

b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));

С точки зрения производительности, reshape бесплатен (он только интерпретирует данные, но не копирует), но permute / transpose часто приводит к попаданию перфорации при копировании данных.

...