У меня есть некоторые предварительно обученные веса (как для слоя, так и для градиента) в виде массивов Numpy, и мне нужно установить их в воссозданной сети.
Пример части моей сети:
X_input = Input((4,256,256))
# batchSize is 4
# size so far: (batchSize,4,256,256)
X = Conv2D(96,(11,11), strides=(4,4), data_format = 'channels_first')(X_input)
# output of the convolution has size: (batchSize, 96, 62, 62)
X = BatchNormalization(axis = 1)(X)
X = Activation('relu')(X)
X = MaxPooling2D((3, 3), strides=(2, 2), data_format='channels_first')(X)
Массив чисел np.ar, который я должен установить в слое Conv2D, имеет форму: (96, 4, 11, 11)
На самом деле я могу вызывать функцию set_weights () как с помощью Sequential () модель как:
model.get_layer('layerName').set_weights(myNpArrayWeights)
Но если я это сделаю, это выдаст ошибку:
ValueError: You called `set_weights(weights)` on layer "step2_conv1" with a
weight list of length 96, but the layer was expecting 2 weights.
Provided weights: [[[[ 3.87499551e-03 1.32818555e-03 2.97062146e-0...
как если бы фигура была неправильной?
Поэтому я попытался ввести2 теста веса с использованием np.array([1,2])
.Это сообщение об ошибке:
ValueError: Fetch argument <tf.Variable 'step2_conv1_4/kernel:0'
shape=(11, 11, 4, 96) dtype=float32_ref> cannot be interpreted as a Tensor.
(Tensor Tensor("step2_conv1_4/kernel:0", shape=(11, 11, 4, 96), dtype=float32_ref)
is not an element of this graph.)
Как мне решить эту проблему?
Как я могу установить веса также для градиента?
Версия Python: 3.6.5
Версия Keras: 2.2.4
Версия Tensorflow: 1.13.1
EDIT
Для первой ошибки Value:
InУровень Conv2D установлен use_bias=False
так, чтобы он ожидал только 1 массив весов, если use_bias установлен True, тогда дополнительный слой весов будет рассматриваться в слое.
Для второй ошибки ValueError:
Перед созданием экземпляра модели необходимо очистить сеанс, потому что вы, возможно, многократно запускали модель (как я), и, очевидно, Tensorflow запутывается в представленных множественных графах.
Для очистки сеанса выполните:
keras.backend.clear_session()