Как сделать связанный (tied) maxpooling в Keras?
0 голосов
/ 07 апреля 2020

Привет, я пытаюсь реализовать обратный максимальный пулинг (tied max unpooling) в Keras.

import keras
from keras import backend as K
from keras.datasets import mnist
from keras.layers import (BatchNormalization, Conv2D, Conv2DTranspose, Dense,
                          Flatten, Input, Layer, MaxPooling2D, Reshape,
                          UpSampling2D)
from keras.models import Model

class InverseMaxPool2D(Layer):
    """Tied max-unpooling layer.

    Upsamples its input by ``size`` and masks the result with the argmax
    locations of a paired ``MaxPooling2D`` layer, so values are routed back
    only to the positions that originally won the max.

    The mask is obtained as ``d(sum(maxpool.output)) / d(maxpool.input)``,
    which is 1 at the winning positions and 0 elsewhere.

    Args:
        maxpool: the ``MaxPooling2D`` layer instance this layer is tied to.
        size: ``(rows, cols)`` upsampling factors; normally
            ``maxpool.pool_size``.
        activation: accepted for signature compatibility but currently
            unused.  # NOTE(review): confirm whether it should be applied.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, maxpool, size, activation=None, **kwargs):
        # Bug fix: the base Layer constructor must be called, otherwise the
        # layer is never properly initialized (no name, no trainable flag,
        # no inbound-node bookkeeping) and adding it to a model fails.
        super(InverseMaxPool2D, self).__init__(**kwargs)
        self.size = size
        self.max_pool = maxpool
        self.activation = activation  # kept but unused (see docstring)

    def build(self, input_shape):
        # No trainable weights; just mark the layer as built.
        super(InverseMaxPool2D, self).build(input_shape)

    def call(self, incoming):
        # Nearest-neighbour upsampling by the pooling factors
        # (channels_last: axis 1 = rows, axis 2 = cols).
        output = K.repeat_elements(incoming, self.size[0], axis=1)
        output = K.repeat_elements(output, self.size[1], axis=2)

        # 0/1 mask of the argmax positions of the tied pooling layer.
        # NOTE(review): K.gradients returns a *list* with one tensor, and it
        # ties this layer to the graph the pooling layer was built in —
        # feeding a different input placeholder at run time raises
        # "You must feed a value for placeholder tensor ...".
        mask = K.gradients(K.sum(self.max_pool.output), self.max_pool.input)

        # list * tensor broadcasts to shape (1, ...); [0] drops that axis,
        # yielding the element-wise masked upsampled tensor.
        masked = mask * output
        return masked[0]

    def compute_output_shape(self, input_shape):
        # (batch, rows * size0, cols * size1, channels) — channels_last.
        return (input_shape[0],
                input_shape[1] * self.size[0],
                input_shape[2] * self.size[1],
                input_shape[3])



# Target placeholder for the reconstruction loss.
t_target = Input(batch_shape=(50, 28, 28, 1), name='target_var')

# Conv -> BN -> MaxPool -> tied unpool -> deconv autoencoder.
network = keras.Sequential()

network.add(Conv2D(filters=32, kernel_size=(3, 3), padding="same",
                   activation=keras.activations.relu,
                   kernel_initializer=keras.initializers.glorot_uniform(seed=42),
                   data_format="channels_last",
                   batch_input_shape=(50, 28, 28, 1)))
network.add(BatchNormalization())

network.add(MaxPooling2D(pool_size=(2, 2)))

# Tie the unpooling layer to the MaxPooling2D layer just added (index 2).
network.add(InverseMaxPool2D(network.layers[2],
                             size=network.layers[2].pool_size))

network.add(Conv2DTranspose(filters=1, kernel_size=(3, 3), padding='same',
                            activation=keras.activations.relu,
                            data_format="channels_last"))

# Bug fix: reuse the model's own input placeholder instead of creating a
# second Input(...).  InverseMaxPool2D (via max_pool.output / max_pool.input)
# and BatchNormalization's update ops are wired to the graph built on the
# model's internal placeholder ('conv2d_1_input'); calling network(t_input)
# on a fresh placeholder leaves the internal one unfed at run time and
# raises "You must feed a value for placeholder tensor 'conv2d_1_input'".
t_input = network.input
recon_prediction_expression = network.output

# encode_prediction_expression = encoder(t_input)

learning_rate = 0.01
optimizer = keras.optimizers.SGD(lr=learning_rate, decay=9999e-4,
                                 nesterov=True)

# Documented signature is mean_squared_error(y_true, y_pred): target first.
# (MSE is symmetric, so the numeric value is unchanged.)
loss = keras.losses.mean_squared_error(t_target, recon_prediction_expression)
params = network.trainable_weights

updates = optimizer.get_updates(params=params, loss=loss)

# Backend functions: one training step (applies `updates`), and a pure
# forward pass for reconstruction.
trainAutoencoder = K.function([t_input, t_target], outputs=[loss],
                              updates=updates)
predictReconstruction = K.function([t_input], [recon_prediction_expression])

Во время обучения я продолжаю получать

tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Invalid argument: You must feed a value for placeholder tensor 'conv2d_1_input' with dtype float and shape [50,28,28,1] [[{{node conv2d_1_input}}]]

Похоже, что граф вычислений строится повторно (дублируется).

...