График отключается при добавлении слоев поверх сети ResNet - PullRequest
0 голосов
/ 07 февраля 2019

Я пытаюсь изменить форму входа сети ResNet50.Мне нужны входы с более чем 3 каналами.Приложение ResNet работает, когда вы указываете форму ввода без загрузки весов imagenet, но я хотел бы использовать веса imagenet, чтобы избежать длительной фазы обучения.

Я знаю, что веса imagenet предназначены для формы ввода с тремя каналамино теоретически, обрезав головку сети и добавив новый входной слой, это должно сработать.

Я попытался удалить слой заголовка, но у меня возникли некоторые проблемы, сказав, что количество фильтров отличается от 3

* 1006.*

ValueError: количество входных каналов не соответствует соответствующему размеру фильтра, 6! = 3

    model=keras.applications.resnet50.ResNet50(include_top=False,
               input_shape(200,200,3),weights='imagenet')
    model.layers.pop(0)
    model.layers.pop(0)
    model.layers.pop()
    X_input = Input((200,200,6), name='input_1')
    X = ZeroPadding2D((3, 3), name='conv1_pad')(X_input)
    model = Model(inputs=X, outputs=model(X))
    model.summary()

Я думаю, что можно изменить количество каналов формы входа ивсе еще использую веса от imagenet, но метод, который я попробовал, кажется неправильным.

1 Ответ

0 голосов
/ 08 февраля 2019

Я не уверен, что модель keras поддерживает операции со списками на своих слоях, кажется, что всплывающие слои не заставляют ее забыть ожидаемый размер ввода.

Вы можете инициализировать новую повторную сеть с вашей формой ввода и вручную загрузить веса Imagenet для всех слоев, кроме первых 3, которые ожидают 3 канала в своем входном тензоре.

заимствование нескольких строк из keras.applications.resnet50 приведет к чему-то вроде этого:

import h5py
import keras
from keras_applications.resnet50 import WEIGHTS_PATH_NO_TOP

input_tensor = keras.Input((200, 200, 6))
resnet = keras.applications.ResNet50(
    input_tensor=input_tensor, weights=None, include_top=False
)

weights_path = keras.utils.get_file(
    'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
    WEIGHTS_PATH_NO_TOP,
    cache_subdir='models',
    md5_hash='a268eb855778b3df3c7506639542a6af')

with h5py.File(weights_path, 'r') as f:
    for layer in resnet.layers[3:]:
        if layer.name in f:
            layer.set_weights(f[layer.name].values())

С учетом вышесказанного, тип передаваемого обучения, которое вы пытаетесь сделать, не очень распространен, и яЯ действительно любопытно, если это работаетМожете ли вы обновить, если оно действительно сошлось быстрее?

...