Как я могу правильно извлечь вес из моей CNN? - PullRequest
0 голосов
/ 18 марта 2020

Прежде всего я обучил свою CNN-архитектуру:

adam =  optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, amsgrad=False)
model = Sequential()
model.add(Conv2D(20, (3,3), activation='relu', input_shape =(5,5,1), padding='same', kernel_initializer='he_normal'))
model.add(Conv2D(30, (3,3), activation='relu', padding='same', kernel_initializer='he_normal'))
model.add(Dropout(0.5))
#model.add(MaxPooling2D(2,2)) # because the ROI is already small, we don't need subsampling
model.add(Flatten())
model.add(Dense(1, activation='sigmoid', kernel_initializer='he_normal'))
model.summary()
#plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
# compile the model
model.compile(loss='binary_crossentropy', optimizer= adam, metrics=['accuracy'])

history = model.fit(X_train, Y_train, epochs=200, callbacks=[model_checkpoint], batch_size=1, verbose=1, shuffle=True, validation_split=0.5)

В то же время я сохранил все веса за эпоху благодаря:

model_checkpoint=ModelCheckpoint('model_test{epoch:02d}.h5',save_freq=1,save_weights_only=True)

Затем я извлек свои веса, например, для первой эпохи "model_test01.h5"

import h5py
import numpy as np
def isGroup(obj):
    if isinstance(obj,h5py.Group):
        return True
    return False

def isDataset(obj):
    if isinstance(obj,h5py.Dataset):
        return True
    return False

def getDatasetFromGroup(datasets,obj):
    if isGroup(obj):
        for key in obj:
            x = obj[key]
            getDatasetFromGroup(datasets,x)
    else:
        datasets.append(obj)

def getWeightsForLayer(layerName, filename):
   weights = []
   with h5py.File(filename, mode='r') as f:
       for key in f:
           if layerName in key:
              obj = f[key]
              datasets = []
              getDatasetFromGroup(datasets,obj)

              for dataset in datasets:
                  w = np.array(dataset)
                  weights.append(w)
   return weights
           #print(key, f[key])
           #o = f[key]
           #for key1 in o:
               #print(key1,o[key1])
               #r = o[key1]
               #for key2 in r:
                   #print(key2,r[key2])
weights = getWeightsForLayer("conv2d_6","./model_test01.h5")
#for w in weights:
    #print(w.shape)
print(weights)

Но я не могу понять вывод, потому что список "весит" с двумя элементами float32 (в основном два numpy массива) первый с 20 элементов (предположение 20 - номер фильтра в первом слое свертки) и второй с размером (3,3,1,20) (поэтому невозможно открыть). Как я могу понять этот вывод?

1 Ответ

0 голосов
/ 19 марта 2020

Два «элемента с плавающей точкой», которые у вас есть, соответствуют весам фильтра и смещениям слоя конвона. Веса фильтра будут иметь форму (3, 3, 1, 20), а смещения будут иметь форму (20), потому что у вас есть 20 фильтров и одно значение смещения для каждого установщика. ,

(3, 3, 1, 20) представляется в виде (ширина_фильтра, высота_фильтра, глубина_фильтра, номер_фильтра)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...