Оптимизация трех вложенных циклов с многократным вычислением в MATLAB - PullRequest
2 голосов
/ 20 февраля 2020

Для следующего кода я хочу оптимизировать его, используя шаблон, представленный в этом решении . Однако проблема заключается в том, как справиться со ссылкой на три вложенных цикла в одном выражении. Более того, состояние сильно отличается от этого поста.

подсказка: W и S являются NxN разреженными двойными матрицами.

    for i=1:N
    for j=1:N
        for k=1:N
            if W(j,k)~=0       
                temp(k)=S(i,j)-S(i,k); 
            end
        end
              sum_temp=max(temp)+sum_temp;
              temp=0;
    end
    B(i,i)=sum_temp;
    sum_temp=0;
end

1 Ответ

2 голосов
/ 22 февраля 2020

В этой ситуации я бы предпочел не векторизовать ваше решение. Вычисление S(i,j)-S(i,k) для каждой комбинации означало бы промежуточный результат размера [N, N, N]. Вместо этого я просмотрел ваш код и исключил как можно больше итераций, не увеличивая потребление памяти. Шаг за шагом, чтобы вы могли понять, как я там оказался.

N=30;
S=rand(N,N);
W=rand(N,N)<.1;
sum_temp=0;
temp=0;
%Your original code for reference
for i=1:N
    for j=1:N
        for k=1:N
            if W(j,k)~=0
                temp(k)=S(i,j)-S(i,k);
            end
        end
        sum_temp=max(temp)+sum_temp;
        temp=0;
    end
    B(i,i)=sum_temp;n
    sum_temp=0;
end
B_orig=B;
%1) you only want the max, no need to make temp a vector
for i=1:N
    sum_temp=0;
    for j=1:N
        temp=0;
        for k=1:N
            if W(j,k)~=0
                temp=max(temp,S(i,j)-S(i,k));
            end
        end
        sum_temp=temp+sum_temp;
    end
    B(i,i)=sum_temp;
end
assert(all(all(B==B_orig)))
%2) eliminate the outer loop
sum_temp=zeros(N,1);
for j=1:N
    temp=zeros(N,1);
    for k=1:N
        if W(j,k)~=0
            temp=max(temp,S(:,j)-S(:,k));
        end
    end
    sum_temp=temp+sum_temp;
end
B=diag(sum_temp);
assert(all(all(B==B_orig)))

%3) combine the inner loop with the condition
sum_temp=zeros(N,1);
for j=1:N
    temp=zeros(N,1);
    for k=find(W(j,:))
        temp=max(temp,S(:,j)-S(:,k));
    end
    sum_temp=temp+sum_temp;
end
B=diag(sum_temp);
assert(all(all(B==B_orig)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...