Как установить веса и градиентные веса в слое непоследовательной () модели Keras - PullRequest
0 голосов
/ 25 июня 2019

У меня есть некоторые предварительно обученные веса (как для слоя, так и для градиента) в виде массивов Numpy, и мне нужно установить их в воссозданной сети.

Пример части моей сети:

X_input = Input((4,256,256))
# batchSize is 4
# size so far: (batchSize,4,256,256)

X = Conv2D(96,(11,11), strides=(4,4), data_format = 'channels_first')(X_input)

# output of the convolution has size: (batchSize, 96, 62, 62)

X = BatchNormalization(axis = 1)(X) 
X = Activation('relu')(X)
X = MaxPooling2D((3, 3), strides=(2, 2), data_format='channels_first')(X)

Массив чисел np.ar, который я должен установить в слое Conv2D, имеет форму: (96, 4, 11, 11)

На самом деле я могу вызывать функцию set_weights () как с помощью Sequential () модель как:

model.get_layer('layerName').set_weights(myNpArrayWeights)

Но если я это сделаю, это выдаст ошибку:

ValueError: You called `set_weights(weights)` on layer "step2_conv1" with a  
weight list of length 96, but the layer was expecting 2 weights. 
Provided weights: [[[[ 3.87499551e-03  1.32818555e-03  2.97062146e-0...

как если бы фигура была неправильной?

Поэтому я попытался ввести2 теста веса с использованием np.array([1,2]).Это сообщение об ошибке:

ValueError: Fetch argument <tf.Variable 'step2_conv1_4/kernel:0' 
shape=(11, 11, 4, 96) dtype=float32_ref> cannot be interpreted as a Tensor. 
(Tensor Tensor("step2_conv1_4/kernel:0", shape=(11, 11, 4, 96), dtype=float32_ref)
is not an element of this graph.) 

Как мне решить эту проблему?

Как я могу установить веса также для градиента?

Версия Python: 3.6.5
Версия Keras: 2.2.4
Версия Tensorflow: 1.13.1

EDIT

Для первой ошибки Value:

InУровень Conv2D установлен use_bias=False так, чтобы он ожидал только 1 массив весов, если use_bias установлен True, тогда дополнительный слой весов будет рассматриваться в слое.

Для второй ошибки ValueError:

Перед созданием экземпляра модели необходимо очистить сеанс, потому что вы, возможно, многократно запускали модель (как я), и, очевидно, Tensorflow запутывается в представленных множественных графах.

Для очистки сеанса выполните:

keras.backend.clear_session()
...