Как избежать циклов для умножения всех перестановок матрицы? - PullRequest
0 голосов
/ 04 ноября 2019

У меня есть следующий код:

  N=8;
  K=10;
  a=zeros(1,N^(K-1));
  b=zeros(1,N^(K-1));

  for ii=1:K
    p0{ii}=rand(1,N);
    p1{ii}=rand(1,N);
  end

  k=1;
  for j1=1:N
    for j3=1:N
      for j4=1:N
        for j5=1:N
          for j6=1:N
            for j7=1:N
              for j8=1:N
                for j9=1:N
                  for j10=1:N
                    a(k)=p0{1}(j1)*p0{3}(j3)*p0{4}(j4)*p0{5}(j5)*p0{6}(j6)*p0{7}(j7)*p0{8}(j8)*p0{9}(j9)*p0{10}(j10);
                    b(k)=p1{1}(j1)*p1{3}(j3)*p1{4}(j4)*p1{5}(j5)*p1{6}(j6)*p1{7}(j7)*p1{8}(j8)*p1{9}(j9)*p1{10}(j10);
                    k=k+1;
                  end
                end
              end
            end
          end
        end
      end
    end
    end

Я не могу оценить этот код для N=8, потому что это занимает много времени. p0 и p1 - матрицы размером KxN. Вложенный цикл for пропускает одну строку из p0 и p1, здесь вторая строка соответствует индексу j2. Остальные элементы матрицы умножаются друг на друга. Таким образом, в общей сложности есть N^(K-1) умножений для получения векторов a и b.

Есть ли способ сделать это без использования циклов или, по крайней мере, в некоторых разумныхвремя?

1 Ответ

2 голосов
/ 04 ноября 2019

По сути, вы просто умножаете каждый элемент из каждой ячейки p0 (или p1) друг на друга. Используя немного магии из reshape и поэлементного умножения , это можно упростить до одного цикла.

Давайте посмотрим на следующий код:

N = 3;
K = 10;

for ii = 1:K
  p0{ii} = rand(1, N);
  p1{ii} = rand(1, N);
end  

a = zeros(1, N^(K-1));
b = zeros(1, N^(K-1));

%for ii = 1:K
%  p0{ii} = rand(1, randi(N));
%  p1{ii} = rand(1, randi(N));
%end

tic;
k = 1;
for j1 = 1:N
  for j3 = 1:N
    for j4 = 1:N
      for j5 = 1:N
        for j6 = 1:N
          for j7 = 1:N
            for j8 = 1:N
              for j9 = 1:N
                for j10 = 1:N
                  a(k) = p0{1}(j1)*p0{3}(j3)*p0{4}(j4)*p0{5}(j5)*p0{6}(j6)*p0{7}(j7)*p0{8}(j8)*p0{9}(j9)*p0{10}(j10);
                  b(k) = p1{1}(j1)*p1{3}(j3)*p1{4}(j4)*p1{5}(j5)*p1{6}(j6)*p1{7}(j7)*p1{8}(j8)*p1{9}(j9)*p1{10}(j10);
                  k = k+1;
                end
              end
            end
          end
        end
      end
    end
  end
end
toc;

tic;
aa = p0{1};
bb = p1{1};
% For MATLAB versions R2016 and newer:
for jj = 3:K
  aa = reshape(aa .* p0{jj}.', 1, numel(aa) .* numel(p0{jj}));
  bb = reshape(bb .* p1{jj}.', 1, numel(bb) .* numel(p1{jj}));
end
% For MATLAB versions before R2016b: 
%for jj = 3:K
%  aa = reshape(bsxfun(@times, aa, p0{jj}.'), 1, numel(aa) .* numel(p0{jj}));
%  bb = reshape(bsxfun(@times, bb, p1{jj}.'), 1, numel(bb) .* numel(p1{jj}));
%end
toc;

numel(find(aa ~= a))
numel(find(bb ~= b))

Вывод:

Elapsed time is 2.39744 seconds.
Elapsed time is 0.00070405 seconds.
ans = 0
ans = 0

Кажется, a и aa, а также b и bb фактически равны, и предлагаемое решение намного быстрее. Я протестировал N = 8 только для своего решения:

Elapsed time is 1.54249 seconds.

Если вы замените инициализацию p0 и p1, раскомментировав соответствующие строки, вы увидите, что мое решение также допускает переменную длину длякаждая p0 (или p1) ячейка. Обратите внимание: это не работает для вашего первоначального решения из-за жесткого кодирования, поэтому сравнение здесь невозможно.

Также обратите внимание, что jj = 3:N здесь также жестко закодировано. Если другие части должны быть пропущены, это необходимо изменить соответствующим образом!

Надеюсь, это поможет!

...