Вот моя numba версия алгоритма im2col_2d для свертки .
@njit
def im2col_2d(mat, fil, res=1):
'''
Expects input and kernel to be square shape
Returns : im2col view with shape - (ker_sz,ker_sz,img_sz,img_sz)
'''
# Parameters
row_range = col_range = len(mat)-len(fil)+1
ker_sz=len(fil)
s0, s1 = mat.strides
shp = ker_sz,ker_sz,row_range,col_range
strd = s0,s1,s0,s1
out_view = np.lib.stride_tricks.as_strided(mat, shape=shp, strides=strd)
return out_view
Размер ввода : (1026,1026), Размер ядра : (3,3), Нет заполнения и Шаг 1, dtype: float64.
Этот код занимает незначительное время (< 1 микрос c) для выполнения; но он возвращает несмежный массив. Кажется, что мы не можем изменить форму несмежного массива внутри функции jumba jited. Но даже если я изменю форму массива снаружи, это займет около 12ms , даже функция 'copy ()' (внутренняя или внешняя функция) в этом массиве занимает 12 мс.
Окончательный размер будет (9, 1048576) после изменения формы.
a=a=im2col_2d(img_pad, fil,res=1)
a.shape
# (3, 3, 1024, 1024)
a.flags
#C_CONTIGUOUS : False
#F_CONTIGUOUS : False
#OWNDATA : False
#WRITEABLE : True
#ALIGNED : True
#WRITEBACKIFCOPY : False
#UPDATEIFCOPY : False
%timeit a.reshape((9,-1))
#100 loops, best of 3: 12.7 ms per loop
Если я делаю непрерывную копию и изменяем , это занимает незначительное время. Но этот процесс переформатирования снова занимает 12 мс.
b=a.copy # This itself take 12 ms
%timeit b.reshape((9,-1))
#1000000 loops, best of 3: 298 ns per loop
Поскольку numba не полностью поддерживает необычные функции индексирования и numpy, некоторые распространенные циклические реализации алгоритма, по-видимому, не могут быть реализованы. Есть ли лучший способ ускорить этот процесс (im2col в целом), используя numba (или быстрее в python)?
Ref: Реализация MATLAB's im2col 'slide в Python