Как сложить два динамически сформированных слоя в Keras? - PullRequest
0 голосов
/ 11 июля 2019

Я пытаюсь реализовать полностью сверточную нейронную сеть в Керасе, которая требует, чтобы сеть принимала входные данные произвольных пространственных измерений.В какой-то момент мне нужно сложить выходы двух разных слоев, но я не могу, потому что большую часть времени выходы имеют разные пространственные измерения.Поэтому я пытаюсь обрезать выходные данные, чтобы сделать их одинаковыми по размеру, прежде чем добавлять их, но чтобы узнать, сколько их нужно обрезать, мне нужно знать разницу между пространственными измерениями, передаваемыми в качестве аргумента в Cropping2D, чтоЯ не знаю, как получить, так как размеры являются динамическими.Как я могу это сделать?

Вот как я пытаюсь обрезать слои, но я застрял, потому что Tensorflow говорит мне, что я не могу сравнить два тензора:

from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers import Lambda, Add, Input, Dropout, Activation
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import CSVLogger, ModelCheckpoint

import os
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
from keras import backend as K

def crop(layers):
    o1, o2 = layers
    output_height1, output_width1 = K.shape(o1)[1], K.shape(o2)[2]
    output_height2, output_width2 = K.shape(o2)[1], K.shape(o2)[2]

    cx = abs( output_width1 - output_width2 )
    cy = abs( output_height2 - output_height1 )

    if output_width1 > output_width2:
            o1 = Cropping2D( cropping=((0,0) ,  (  0 , cx )))(o1)
    else:
            o2 = Cropping2D( cropping=((0,0) ,  (  0 , cx )))(o2)

    if output_height1 > output_height2 :
            o1 = Cropping2D( cropping=((0,cy) ,  (  0 , 0 )))(o1)
    else:
            o2 = Cropping2D( cropping=((0, cy ) ,  (  0 , 0 )))(o2)

    return o1 , o2

Вот соответствующая часть в моей сети (o и o2 - это выходы двух предыдущих слоев keras):

o , o2 = Lambda(crop)([o , o2])
o = Add()([ o , o2 ])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...