Tensorflow Keras: градиент выходов по отношению к входам внутри настраиваемой функции потерь - PullRequest
0 голосов
/ 05 мая 2020

Я хочу написать пользовательскую функцию потерь для многослойной сети Perceptron в Keras. Потеря состоит из двух компонентов: первая - это обычное «mse», а вторая - поэлементные градиенты выходных данных по отношению к входным характеристикам. Пусть x будет входом с двумя характеристиками (размер: количество выборок X 2) и y выходом с одним выходом (размер: количество выборок X 1). Я обозначаю производную каждого выходного образца с первой особенностью каждого образца как $\frac{dy[:]}{dx[:,0]}$

. Точно так же я хочу вычислить следующее выражение внутри функции потерь:

$$r[:] = y[:] \frac{dy[:]}{dx[:,0]} - x[:,1] \frac{d^2y[:]}{dx[:,0]^2}$$

и возьмите средний квадрат вектора r. Общая потеря представляет собой сумму обычного mse и среднего квадрата r вектора.

Это минимальный воспроизводимый пример кода, который я пробовал:

import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as kb

def custom_loss_envelop(model_inputs, model_outputs):

    def custom_loss(y_true,y_pred):
        mse_loss = keras.losses.mean_squared_error(y_true, y_pred)
        print()
        print(model_inputs); print()
        print(model_outputs); print()
        dy_dx = kb.gradients(model_outputs, tf.gather(model_inputs, [0], axis=1))
        print(dy_dx); print()
        d2y_dx2 = kb.gradients(dy_dx, tf.gather(model_inputs, [0], axis=1))
        print(d2y_dx2); print()

        r = tf.multiply(model_outputs, tf.gather(dy_dx, [0], axis=1)) - tf.multiply(tf.gather(model_inputs, [1], axis=1), tf.gather(d2y_dx2, [0], axis=1)) # y*dy_dx[0] - x[1]*d2y_dx[0]2

        r = kb.mean(kb.square(r))
        loss = mse_loss + r
        return loss

    return custom_loss

nx=100;
inputs_train=np.random.uniform(0,1,(nx,2)); outputs_train=np.random.uniform(0,1,(nx,1))
inputs_val=np.random.uniform(0,1,(int(nx/2),2)); outputs_val=np.random.uniform(0,1,(int(nx/2),1))
n_hidden_units=50; l2_reg_lambda=0; learning_rate=0.001; dropout_factor=0.0; epochs=3

model = keras.Sequential();
model.add(keras.layers.Dense(n_hidden_units, activation='relu', input_shape=(inputs_train.shape[1],), kernel_regularizer=keras.regularizers.l2(l2_reg_lambda))); #first hidden layer
model.add(keras.layers.Dropout(dropout_factor)); model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dense(n_hidden_units, activation='relu', kernel_regularizer = keras.regularizers.l2(l2_reg_lambda)));
model.add(keras.layers.Dropout(dropout_factor)); model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dense(n_hidden_units, activation='relu', kernel_regularizer = keras.regularizers.l2(l2_reg_lambda)));
model.add(keras.layers.Dropout(dropout_factor)); model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dense(outputs_train.shape[1], activation='linear'));
optimizer1 = keras.optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=True)

model.compile(loss=custom_loss_envelop(model.inputs, model.outputs), optimizer=optimizer1, metrics=['mse'])

model.fit(inputs_train, outputs_train, batch_size=100, epochs=epochs, shuffle=True, validation_data=(inputs_val,outputs_val), verbose=1)

Здесь , Я произвольно сгенерировал образцы для обучения и проверки. Я получаю следующие тензорные формы: model_inputs: [<tf.Tensor 'dense_input:0' shape=(None, 2) dtype=float32>], model_outputs: [<tf.Tensor 'dense_3/Identity:0' shape=(None, 1) dtype=float32>] и dy_dx: [None]. Первые 2 такие, как ожидалось, но производная также должна иметь форму (None, 1), но это не так. Следовательно, я получаю AttributeError: 'NoneType' object has no attribute 'op' ошибку в строке d2y_dx2 = kb.gradients(dy_dx, tf.gather(model_inputs, [0], axis=1))

Любая помощь приветствуется либо для устранения этой проблемы, либо с альтернативным решением.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...