Я пытаюсь реализовать настраиваемую функцию потерь с использованием MSE полей градиента, которая может быть легко реализована с помощью np.gradient ().
Но когда я интегрирую эту функцию потерь в свою модель keras, она не будет работать.Причина в том, что np.gradient () принимает входные данные np.array, поэтому мне пришлось преобразовать тензор в массив с помощью K.eval ().Однако такое преобразование недопустимо, поскольку эти тензоры не имеют конкретных значений в момент построения модели, а являются просто заполнителями.
Есть ли решение проблемы?
ps здесь похожая проблема, но решение для моего случая бесполезно: Преобразование Tensor в np.array с использованием K.eval () в Keras возвращает InvalidArgumentError
def comput_loss(x):
y_true, y_pred = x
# Compute the Perceptual loss ###based on GRADIENT-field MSE
grad_true = np.gradient(K.eval(y_true),axis=1)
grad_pred = np.gradient(K.eval(y_pred),axis=1)
grad_loss =
tf.losses.mean_squared_error( grad_true,grad_pred)
return [grad_loss]
# Input LR images
img_lr = Input(shape=self.shape_lr,batch_shape=self.batch_shape_lr)
img_hr = Input(shape=self.shape_hr,batch_shape=self.batch_shape_hr)
# Create a high resolution image from the low resolution one
generated_hr = self.generator(img_lr)
# In the combined model we only train the generator
self.discriminator.trainable = False
self.RaGAN.trainable = False
# Output tensors to a Model must be the output of a Keras `Layer`
total_loss = Lambda(comput_loss, name='comput_loss')([img_hr, generated_hr])
percept_loss = Lambda(lambda x: self.loss_weights['percept'] * name='grad_loss')(total_loss)
# Create model
model = Model(inputs=[img_lr, img_hr], outputs=[grad_loss])
# Add the loss of model and compile
# model.add_loss(loss)
model.add_loss(percept_loss)
model.compile(optimizer=Adam(self.gen_lr))
# Create metrics of ESRGAN
model.metrics_names.append('percept_loss')
model.metrics_tensors.append(percept_loss)
return model
tenorflow.python.framework.errors_impl.InvalidArgumentError: Необходимо передать значение для тензора-заполнителя 'input_6' с плавающей точкой dtype и формой [2,16,16,16,1]