Реализация "полной свертки" для нахождения градиента по отношению к входам слоя свертки - PullRequest
0 голосов
/ 14 января 2020

Я пытался реализовать "полную свертку" относительно входов уровня свертки. Согласно этой статье, это выглядит так:

enter image description here

Итак, я написал эту функцию:

def full_convolve(filters, gradient):
    filters = np.ones((5,5))
    gradient = np.ones((8,8))
    result = list()
    output_shape = 12
    filter_r = filters.shape[0] - 1
    filter_c = filters.shape[1] - 1
    gradient_r = gradient.shape[0] - 1
    gradient_c = gradient.shape[1] - 1

    for i in range(0,output_shape):
        if (i <= filter_r):
            row_slice = (0, i + 1)
            filter_row_slice = ( 0 , i + 1)
        elif ( i > filter_r and i <= gradient_r):
            row_slice = (i - filter_r, i + 1)
            filter_row_slice = (0, i + 1)
        else: 
            rest = ((output_shape - 1) -  i )
            row_slice = (gradient_r  - rest, i + 1 )
            filter_row_slice = (0 ,rest + 1)
        for b in range(0,output_shape):
            if (b <= filter_c):
                col_slice = (0, b + 1)
                filter_col_slice = (0, b+1)
            elif (b > filter_c and b <= gradient_c):
                col_slice = (b - filter_c, b + 1)
                filter_col_slice = (0,b+1)
            else:
                rest = (output_shape - 1 ) - b 
                col_slice = (gradient_r - rest , b + 1)
                filter_col_slice = (0, rest + 1)
            r = np.sum(gradient[row_slice[0] : row_slice[1], col_slice[0] : col_slice[1]] * filters[filter_row_slice[0]: filter_row_slice[1], filter_col_slice[0]: filter_col_slice[1]])
            result.append(r)
    result = np.asarray(result).reshape(12,12)

Я проверил это с единицами, и вывод кажется правильным (если я правильно понял "полную свертку"):

[[ 1.  2.  3.  4.  5.  5.  5.  5.  4.  3.  2.  1.]
 [ 2.  4.  6.  8. 10. 10. 10. 10.  8.  6.  4.  2.]
 [ 3.  6.  9. 12. 15. 15. 15. 15. 12.  9.  6.  3.]
 [ 4.  8. 12. 16. 20. 20. 20. 20. 16. 12.  8.  4.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 5. 10. 15. 20. 25. 25. 25. 25. 20. 15. 10.  5.]
 [ 4.  8. 12. 16. 20. 20. 20. 20. 16. 12.  8.  4.]
 [ 3.  6.  9. 12. 15. 15. 15. 15. 12.  9.  6.  3.]
 [ 2.  4.  6.  8. 10. 10. 10. 10.  8.  6.  4.  2.]
 [ 1.  2.  3.  4.  5.  5.  5.  5.  4.  3.  2.  1.]]

Однако мне не нравятся все эти ручные проверки и операторы if / else. Я чувствую, что есть лучший способ реализовать это в NumPy (возможно, используя некоторые дополнения нулями или что-то вроде этого). Кто-нибудь может предложить лучший подход? Спасибо

...