Привет, я пытаюсь реализовать обратный максимальный пул в кератах.
from keras.layers import Input, Dense ,Reshape , Flatten
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D,BatchNormalization, Conv2DTranspose
from keras.models import Model
from keras.datasets import mnist
class InverseMaxPool2D(Layer):
'''
This layer makes tied max pooling with the one before
'''
def __init__(self, maxpool,size,activation=None , **kwargs):
self.size = size
self.max_pool = maxpool
def build(self, input_shape):
super(InverseMaxPool2D, self).build(input_shape)
def call(self, incoming):
output = K.repeat_elements(incoming, self.size[0], axis=1)
output = K.repeat_elements(output, self.size[1], axis=2)
W = K.gradients(K.sum(self.max_pool.output), self.max_pool.input)
f = W * output
f = f[0]
return f
def compute_output_shape(self, input_shape):
return (input_shape[0],input_shape[1]*self.size[0],input_shape[2]*self.size[1],input_shape[3] )
t_input = Input( batch_shape=(50,28,28,1))
t_target = Input(batch_shape=(50,28,28,1),name='target_var')
network = keras.Sequential()
network.add(Conv2D( filters=32, kernel_size= (3,3), padding="same", activation=keras.activations.relu,kernel_initializer=keras.initializers.glorot_uniform(seed=42), data_format="channels_last",batch_input_shape = (50,28,28,1) ))
network.add(BatchNormalization())
network.add(MaxPooling2D( pool_size=(2,2)))
network.add(InverseMaxPool2D( network.layers[2], size = network.layers[2].pool_size))
network.add(Conv2DTranspose( filters=1, kernel_size=(3,3),padding='same', activation=keras.activations.relu, data_format="channels_last"))
learning_rate = 0.01
optimizer = keras.optimizers.SGD(lr = learning_rate , decay=9999e-4, nesterov=True)
recon_prediction_expression = network(t_input)
#encode_prediction_expression = encoder(t_input)
loss = keras.losses.mean_squared_error(recon_prediction_expression, t_target)
params = network.trainable_weights
updates = optimizer.get_updates(params= params, loss=loss)
trainAutoencoder = K.function([t_input, t_target], outputs = [loss], updates = updates)
predictReconstruction = K.function([t_input], [recon_prediction_expression])
Во время обучения я продолжаю получать
tenorflow / core / common_runtime / base_collective_executor. cc: 217] BaseCollectiveExecutor :: StartAbort Неверный аргумент: Вы должны передать значение для тензора заполнителя 'conv2d_1_input' с плавающей точкой dtype и формой [50,28,28,1] [[{{node conv2d_1_input}}]]
Кажется, что оно повторяется график расчета