Как удалить определенные c фильтры в CNN keras - PullRequest
0 голосов
/ 06 мая 2020

Предположим, я создал модель следующим образом:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D

model = Sequential()

model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape,
                     kernel_initializer='he_normal',))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))

После того, как я закончил обучение модели, как я могу удалить фильтры с индексами 1,5 и 9? Таким образом, общее количество оставшихся фильтров будет 29, но без исходных фильтров, расположенных в точках 1,5 и 9.

Я хотел бы использовать эту «модифицированную модель» для прогнозирования тестовых данных еще раз, но без них. фильтры

score = modified_model.evaluate(x, y)

1 Ответ

1 голос
/ 06 мая 2020

это решение для замены желаемых фильтров в обученном net нулями

inp = Input((10,10,3))
c = Conv2D(32, kernel_size=(3, 3),
           activation='relu',
           kernel_initializer='he_normal')
f = Flatten()
d = Dense(10, activation='softmax')

x = c(inp)
x = f(x)
out = d(x)
model = Model(inp, out)
print(model.summary())

model.fit(.....)

w,b = c.get_weights()
w[:,:,:,1] = 0
w[:,:,:,5] = 0
w[:,:,:,9] = 0
c.set_weights([w,b])

о модификации обученного net удаление весов невозможно. он несовместим со слоем ниже в вашем случае плоский и плотный

w,b = c.get_weights()
w = np.delete(w, [1,5,9], -1)
b = np.delete(b, [1,5,9], 0)


new_c = Conv2D(29, kernel_size=(3, 3),
               activation='relu',
               kernel_initializer='he_normal',
               trainable=False)

x = new_c(inp)
x = f(x)
out = d(x) # -----> error!
new_model= Model(inp, out)

new_c.set_weights([w,b])

print(new_model.summary())

вы можете создать новую сеть, в которой вы управляете старыми фильтрами conv2d, но вам нужно переобучить слой ниже

w,b = c.get_weights()
w = np.delete(w, [1,5,9], -1)
b = np.delete(b, [1,5,9], 0)

new_inp = Input((10,10,3))
new_c = Conv2D(29, kernel_size=(3, 3),
           activation='relu',
           kernel_initializer='he_normal',
           trainable=False)
new_f = Flatten()
new_d = Dense(10, activation='softmax')

new_x = new_c(new_inp)
new_x = new_f(new_x)
new_out = new_d(new_x)
new_model = Model(new_inp, new_out)

new_c.set_weights([w,b])

print(new_model.summary())

new_model.fit(.....)
...