Реализация гиперсети Keras? - PullRequest
0 голосов
/ 28 июня 2019

Какой самый простой способ реализации гиперсети в Керасе?То есть, где одна ветвь сети создает веса для другой?В частности, я хотел бы выполнить сопоставление с шаблоном, где я передаю шаблон в ветку CNN, которая генерирует сверточное ядро ​​для ножки, которая работает с основным изображением.Часть, в которой я не уверен, это то, где у меня есть слой CNN, который подается на весы извне, но градиенты все еще правильно проходят для тренировки.

1 Ответ

1 голос
/ 29 июня 2019

нога весов:

Для ноги весов просто создайте регулярную сеть, как вы это сделали бы с Keras.

Убедитесь, что его выходные данные имеют форму, подобную (spatial_kernel_size1, spatial_kernel_size2, input_channels, output_channels)

Используя функциональный API, вы можете создать несколько весов, например:

inputs = Input((imgSize1, imgSize2, imgChannels))

w1 = Conv2D(desired_channels, ....)(inputs)
w2 = Conv2D(desired_channels2, ....)(inputs or w1)
....

Здесь вы должны применить какое-то объединение, так как ваши выходы будут иметь огромный размер, и вам, вероятно, понадобятся фильтры с небольшими размерами, такими как 3, 5 и т. Д.

w1 = GlobalAveragePooling2D()(w1) #maybe GlobalMaxPooling2D
w2 = GlobalAveragePooling2D()(w2)

Если вы используете изображения фиксированного размера, вы также можете использовать другие виды объединения или сглаживания и плотного и т. Д.

Убедитесь, что вы изменили вес для правильной формы.

w1 = Reshape((size1,size2,input_channels, output_channels))(w1)
w2 = Reshape((sizeA, sizeB, input_channels2, output_channels2))(w2)
....

Выбор количества каналов зависит от вас

сверточная нога:

Теперь эта часть будет использовать только «не обучаемые» свертки, их можно найти непосредственно в бэкэнде и использовать в Lambda слоях:

out1 = Lambda(lambda x: K.conv2d(x[0], x[1]))([inputs,w1])
out2 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out1,w2])

Теперь то, как вы собираетесь чередовать слои, сколько весов и т. Д., Также следует оптимизировать для себя.

Создать модель:

model = Model(inputs, out2)

Чередование

Вы также можете использовать выходные данные этой ветви в качестве входных данных для ветви генератора веса:

w3 = Conv2D(filters, ...)(out2)
w3 = GlobalAveragePooling2D()(w3)
w3 = Reshape((sizeI, sizeII, inputC, outputC))(w3)
out3 = Lambda(lambda x: K.conv2d(x[0], x[1]))([out2,w3])
...