В Tensorflow - Можно ли заблокировать определенные c фильтры свертки в слое или удалить их вообще? - PullRequest
2 голосов
/ 04 апреля 2020

Когда я использую трансферное обучение в Tensorflow, я знаю, что можно заблокировать слои от дальнейшего обучения, выполнив:

for layer in pre_trained_model.layers:
    layer.trainable = False

Возможно ли вместо этого заблокировать указанные c фильтры в слое? Например, если весь слой содержит 64 фильтра, возможно ли:

  • заблокировать только некоторые из них, которые, по-видимому, содержат разумные фильтры, и повторно обучить те, которые не имеют?

ИЛИ

  • удалить неоправданно выглядящие фильтры из слоев и переобучить без них? (например, чтобы увидеть, сильно ли изменятся переобученные фильтры)

1 Ответ

3 голосов
/ 04 апреля 2020

Одним из возможных решений является реализация пользовательского уровня, который разбивает свертку на отдельные number of filters свертки и устанавливает для каждого канала (который является сверткой с одним выходным каналом) значение trainable или not trainable. Например:

import tensorflow as tf
import numpy as np

class Conv2DExtended(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, **kwargs):
        self.filters = filters
        self.conv_layers = [tf.keras.layers.Conv2D(1, kernel_size, **kwargs) for _ in range(filters)]
        super().__init__()

    def build(self, input_shape):
        _ = [l.build(input_shape) for l in self.conv_layers]
        super().build(input_shape)

    def set_trainable(self, channels):
        """Sets trainable channels."""
        for i in channels:
            self.conv_layers[i].trainable = True

    def set_non_trainable(self, channels):
        """Sets not trainable channels."""
        for i in channels:
            self.conv_layers[i].trainable = False

    def call(self, inputs):
        results = [l(inputs) for l in self.conv_layers]
        return tf.concat(results, -1)

И пример использования:

inputs = tf.keras.layers.Input((28, 28, 1))
conv = Conv2DExtended(filters=4, kernel_size=(3, 3))
conv.set_non_trainable([1, 2]) # only channels 0 and 3 are trainable
res = conv(inputs)
res = tf.keras.layers.Flatten()(res)
res = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(res)

model = tf.keras.models.Model(inputs, res)
model.compile(optimizer=tf.keras.optimizers.SGD(),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(np.random.normal(0, 1, (10, 28, 28, 1)),
          np.random.randint(0, 2, (10)),
          batch_size=2,
          epochs=5)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...