Октава / Матлаб: эффективный расчет внутреннего продукта Фробениуса? - PullRequest
8 голосов
/ 07 ноября 2011

У меня есть две матрицы A и B, и я хочу получить:

trace(A*B)

Если я не ошибаюсь, это называется Внутренний продукт Фробениуса .

Меня беспокоит эффективность.Я просто боюсь, что этот прямой подход сначала выполнит все умножение (мои матрицы - тысячи строк / столбцов), и только затем отследит продукт, тогда как операция, которая мне действительно нужна, намного проще.Есть ли функция или синтаксис, чтобы сделать это эффективно?

Ответы [ 3 ]

5 голосов
/ 07 ноября 2011

Правильно ... суммирование поэлементных произведений будет быстрее:

n = 1000

A = randn(n);
B = randn(n);

tic
sum(sum(A .* B));
toc

tic
sum(diag(A * B'));
toc
Elapsed time is 0.010015 seconds.
Elapsed time is 0.130514 seconds.
2 голосов
/ 07 ноября 2011

sum(sum(A.*B)) избегает полного умножения матриц

1 голос
/ 27 мая 2013

Как насчет использования векторного умножения?

(A(:)')*B(:)

Проверка времени выполнения

Сравнение четырех вариантов с A и B размера 1000 на 1000:
1. векторное внутреннее произведение: A(:)'*B(:) (этот ответ) заняло только 0.0011 sec.
2. Использование поэлементного умножения sum(sum(A.*B)) ( John ) ответ) заняло 0.0035 sec.
3. Трассировка trace(A*B') (предложенная OP) заняла 0.054 sec.
4. Сумма диагонали sum(diag(A*B')) (опция отклонена John ) заняла 0.055 sec.

Возьмите домой сообщение: Matlab чрезвычайно эффективен, когда дело доходит до матричного / векторного продукта.Использование векторного внутреннего произведения в x3 раза быстрее даже по сравнению с эффективным решением поэлементного умножения.


Код эталонного теста Код, используемый для проверки времени выполнения

t=zeros(1,4);
n=1000; % size of matrices
it=100; % average results over XX trails
for ii=1:it, 
    % random inputs
    A=rand(n);
    B=rand(n); 
    % John's rejected solution
    tic; 
    n1=sum(diag(A*B'));
    t(1)=t(1)+toc;
    % element-wise solution
    tic;
    n2=sum(sum(A.*B));
    t(2)=t(2)+toc;
    % MOST efficient solution - using vector product
    tic;
    n3=A(:)'*B(:);
    t(3)=t(3)+toc;
    % using trace
    tic;
    n4=trace(A*B');
    t(4)=t(4)+toc;
    % make sure everything is correct
    assert(abs(n1-n2)<1e-8 && abs(n3-n4)<1e-8 && abs(n1-n4)<1e-8);
end;
t./it

Теперь вы можете запустить этот тест в клик .

...