Инициализация сверточного ядра Keras в виде массива numpy - PullRequest
0 голосов
/ 23 апреля 2020

Я хотел бы инициализировать веса для (5,5) сверточного слоя с четырьмя каналами как массив numpy. Вход в этот слой имеет форму (128,128,1). В частности, я хотел бы следующее:

def custom_weights(shape, dtype=None):
    matrix = np.zeros((1,5,5,4))
    matrix[0,2,2,0,0] = 1
    matrix[0,2,1,0,0] = -1

    matrix[0,2,2,0,1] = 1
    matrix[0,3,2,0,1] = -1

    matrix[0,2,2,0,2] = 2
    matrix[0,2,1,0,2] = -1
    matrix[0,2,3,0,2] = -1

    matrix[0,2,2,0,3] = 2
    matrix[0,1,2,0,3] = -1
    matrix[0,3,2,0,3] = -1
    weights = K.variable(matrix)
    return weights

input_shape = (128, 128, 1)
images = Input(input_shape, name='phi_input')

conv1 = Conv2D(4,[5, 5], use_bias = False, kernel_initializer=custom_weights, padding='valid', name='Conv2D_1', strides=1)(images)

Однако, когда я пытаюсь сделать это, я получаю ошибку

Depth of input (1) is not a multiple of input depth of filter (5) for 'Conv2D_1_19/convolution' (op: 'Conv2D') with input shapes: [?,128,128,1], [1,5,5,4].

Является ли моя ошибка в форме веса матрица

1 Ответ

1 голос
/ 23 апреля 2020

В вашем коде много несоответствий, ошибка, которую вы получаете не из данного кода, поскольку она даже не индексирует матрицу должным образом.

matrix = np.zeros((1,5,5,4))
matrix[0,2,2,0,0] = 1

Вы инициализируете numpy массив с 4 измерениями, но с использованием 5 индексов для изменения значения.

Ваши измерения для весов ядра неверны. Вот фиксированный код.

from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
import numpy as np

def custom_weights(shape, dtype=None):
    kernel = np.zeros((5,5,1,4))
    # change value here
    kernel = K.variable(kernel)
    return kernel

input_shape = (128, 128, 1)
images = Input(input_shape, name='phi_input')

conv1 = Conv2D(4,[5, 5], use_bias = False, kernel_initializer=custom_weights, padding='valid', name='Conv2D_1', strides=1)(images)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...