«Только один размер ввода может быть -1, а не 0 и 2» в Керасе - PullRequest
2 голосов
/ 19 апреля 2019

Это резюме моей модели.

Summary of my model

Моя модель в основном похожа на сверточную сеть.

Я хочу, чтобы моя модель работала независимо от ширины ввода,Таким образом, размер ширины выглядит как None.

, и я хочу прикрепить декодер к моей модели.Однако, когда я присоединяю декодер, возникает ошибка.(Если я не присоединяю декодер, программа работает нормально)

Эта часть - моя часть декодера (ниже)

if args.decoder==True:

    decoder = ConvCapsuleLayer(kernel_size=args.kernel_size, num_capsule=4, num_atoms=1, strides=1, padding='same',
                                  routings=1)(conv_cap)

    _, H, W, C, A = decoder.get_shape()

    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([decoder, y])
    masked = Mask()(decoder)

    def shared_decoder(mask_layer):

        recon_1 = layers.Conv2DTranspose(4, (5,5), strides=(2, 2), padding='same',  kernel_initializer='he_normal', name='decoder_1', activation='relu')(mask_layer)
        recon_2 = layers.Conv2DTranspose(8, (5,5), strides=(2, 2), padding='same',  kernel_initializer='he_normal', name='decoder_2', activation='relu')(recon_1)
        recon_3 = layers.Conv2DTranspose(1, (1,1), strides=(1, 1), padding='same',  kernel_initializer='he_normal', name='decoder_3', activation='linear')(recon_2) 
        return recon_3

if args.decoder==True:
    train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)]) # [x:image,y: mask] // [out_seg:length, reconstruction output]
    eval_model = models.Model(x, [out_seg, shared_decoder(masked)])
else:
    train_model = models.Model(inputs=x, outputs=out_seg)
    eval_model = models.Model(inputs=x, outputs=out_seg)
return train_model, eval_model

mask_1 - это мой слой Mask.

Если указана метка, возвращается только канал метки.(masked_by_y)

Если метка не указана, этот слой возвращает только канал с наибольшей суммой значений элемента в conv_capsule_layer_1.(masked)

Форма conv_capsule_layer_1 имеет вид (batch_size = None, height = 50, width = None, num_channel = 4, 1)

То есть mask layer возвращает канал, имеющий наибольшую сумму значений элементов среди четырех каналов.

Затем используйте Conv2DTranspose, чтобы сделать его равным размеру исходного ввода, используя возвращаемое значение (вывод слоя маски).

Однако возникает следующая ошибка

InvalidArgumentError (см. Выше для отслеживания): только один размер ввода может быть -1, а не 0 и 2 [[Node: mask_1 / Reshape_1 = Reshape [T = DT_FLOAT, Tshape = DT_INT32, _device = "/ job: localhost / replica: 0 / task: 0 / device: CPU: 0 "] (mask_1 / boolean_mask / Gather, mask_1 / Reshape_1 / shape)]]

Как сделать так, чтобы переменная длины неиспользуя -1?Я уже пробовал это non_zero_masked = K.reshape(non_zero,[-1, masked.shape[1], masked.shaped[2],1])

Это имя функции call в моем Mask слое

    def call(self, inputs, **kwargs):
    if type(inputs) is list:  # true label is provided with shape = [None, n_classes], i.e. one-hot code.
        assert len(inputs) == 2
        inputs, mask = inputs
        inputs = K.squeeze(inputs, axis=-1) # [batch, input_height, input_width, num_cap, num_atom] -> [batch, input_height, input_width, num_cap] 
    else:  # if no true label, mask by the max length of capsules. Mainly used for prediction

        inputs = K.squeeze(inputs, axis=-1) #[batch, input_height, input_width, num_cap]
        x = K.softmax(K.sqrt(K.sum(K.square(inputs), axis=(1,2)) + K.epsilon())) # x: [batch, 4]
        mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1]) # mask: [batch,4]


    expand_mask = K.reshape(mask,[-1,1,1,mask.shape[1]]) #[batch_size, 1, 1, num_class]
    masked = inputs*expand_mask
    non_zero = tf.boolean_mask(masked, tf.not_equal(masked,0))
    non_zero_masked = K.reshape(non_zero,[-1, masked.shape[1], -1,1])
    return non_zero_masked

Кто-нибудь знает, почему эта ошибка происходит?Как я могу решить это?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...