Как мне скопировать определенные веса слоев из предварительно обученных моделей, используя Tensorflow Keras api? - PullRequest
0 голосов
/ 03 мая 2019

Я пытаюсь обучить виртуальную сеть, которая использует 4-канальный вход, и хочу использовать предварительно обученную модель, такую ​​как VGG16.Имеет смысл, что я не должен использовать начальные блоки конв из VGG16, так как они обучены для 3-канальных входов, и переопределять начальные блоки конв.

Однако я хочу использовать block3 и дальше от VGG16.Как мне добиться этого с помощью Tensorflow Keras api?

Короче говоря, как мне копировать веса из определенных слоев из предварительно обученных моделей.Я использую альфа-версию tennsflow 2.0.

1 Ответ

2 голосов
/ 03 мая 2019

Быстрый способ сделать это - создать новую модель, сочетающую в себе пользовательский ввод и последние слои VGG16. Найдите индекс первого слоя VGG16, который вы хотите сохранить, и подключите его к вновь созданному входу. Затем подключите каждый следующий слой VGG16 вручную, чтобы воссоздать сегмент VGG16. По пути вы можете заморозить слои VGG16.

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D

vgg16 = VGG16()

# Find the index of the first block3 layer
for index in range(len(vgg16.layers)):
    if 'block3' in vgg16.layers[index].name:
        break

# Add your own input
model_input = Input(shape=(224,224,4), name='new_input')
x = Conv2D(...)(model_input)
...

# Connect your last layer to the VGG16 model, starting at the "block3" layer
# Then, you need to connect every layer manually in a for-loop, freezing each layer along the way

for i in range(index, len(vgg16.layers)):
  # freeze the VGG16 layer
  vgg16.layers[i].trainable = False  

  # connect the layer
  x = vgg16.layers[i](x)

model_output = x
newModel = Model(model_input, model_output)

Также убедитесь, что выходные данные ваших пользовательских слоев соответствуют форме, которую слои block3 ожидают в качестве входных данных.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...