Очень медленное выполнение пользовательской функции свертки для нейронной сети в MATLAB - PullRequest
0 голосов
/ 26 августа 2018

У меня есть реализация нейронной сети свертки в MATLAB (из открытого источника DeepLearnToolbox). Следующий код находит свертку различных весов и параметров:

 z = z + convn(net.layers{l - 1}.a{i}, net.layers{l}.k{i}{j}, 'valid');

Чтобы обновить инструмент, я реализовал собственную свертку на основе схемы с фиксированной запятой, используя следующий код:

function result = convolution(image, kernal)

% find dimensions of output
row = size(image,1) - size(kernal,1) + 1;
col = size(image,2) - size(kernal,2) + 1;
zdim = size(image,3);

%create output matrix
output = zeros(row, col);

% flip the kernal
kernal_flipped = fliplr(flipud(kernal));

%find rows and col of kernal for loop iteration
row_ker = size(kernal_flipped,1);
col_ker = size(kernal_flipped,2);

for k = 1 : zdim
    for i = 0 : row-1
        for j = 0 : col-1
            sum = fi(0,1,8,7);
             prod = fi(0,1,8,7);
            for k_row = 1 : row_ker
                for k_col = 1 : col_ker
                    a = image(k_row+i, k_col+j, k);
                    b = kernal_flipped(k_row,k_col);
                    prod = a * b;
                   % convert to fixed point                     
                    prod = fi((product/16384), 1, 8, 7);

                    sum = fi((sum + prod), 1, 8, 7);
                end
            end
            output(i+1, j+1, k) = sum;
        end
    end
end

result = output;
end

Проблема в том, что когда я использую свою сверточную реализацию в более крупном приложении, она работает очень медленно. Есть предложения как улучшить время его выполнения?

1 Ответ

0 голосов
/ 26 августа 2018

MATLAB не поддерживает двухмерную свертку с фиксированной точкой, но, зная, что свертка может быть записана как матричное умножение и что MATLAB поддерживает умножение матрицы с фиксированной точкой , вы можете использовать im2col преобразовать изображение в формат столбца и умножить его на ядро, чтобы сверить их.

row = size(image,1) - size(kernal,1) + 1;
col = size(image,2) - size(kernal,2) + 1;
zdim = size(image,3);

output = zeros(row, col);

kernal_flipped = fliplr(flipud(kernal));

fi_kernel = fi(kernal_flipped(:).', 1, 8, 7) / 16384;   

sz = size(kernal_flipped);
sz_img = size(image);

% Use the generated indexes to convert the image into column format
idx_col = im2col(reshape(1:numel(image)/zdim,sz_img(1:2)),sz,'sliding');
image = reshape(image,[],zdim);

for k = 1:zdim
    output(:,:,k) = double(fi_kernel * reshape(image(idx_col,k),size(idx_col)));
end
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...