Как написать model.predict () с функциональным API Keras в пользовательской функции потерь? - PullRequest
0 голосов
/ 02 апреля 2020

Мне нужно создать собственную функцию потерь для модели Keras. Основная проблема заключается в использовании другой (предварительно обученной) модели при кастомной потере. Чтобы получить данные от ввода, я передаю входной тензор в свою пользовательскую функцию потерь.

def custom_loss_wrapper(self, input_tensor):
    my_model, E, _lambda = self._get_model()

    def my_loss(y_true, y_pred):
        model_pred = my_model.predict(input_tensor)

        loss_value = K.sum( K.pow(K.abs((y_true - model_pred) / K.variable(E)), (_lambda / (_lambda + 1))) * K.pow((y_true - y_pred),2))
        return loss_value

    return my_loss

с использованием вышеуказанной функции в модели, такой как

K.clear_session()
input = Input(shape=(input_count, ))
...
model.compile(optimizer=self.config['TrainingAlgorithm'], loss = self.custom_loss_wrapper(input))

Проблема в вызове функции predict. input_tensor это просто заполнитель, он не имеет никакого значения при компиляции модели. Я также попробовал что-то вроде input_values = K.get_session().run(input_tensor) или K.get_value(input_tensor). Все это дает ошибку, потому что тензор пуст.

Все внутри my_loss должно быть написано с помощью функционального API Keras (K.sum() et c.).

Как написать model.predict() с функциональным API Keras?

...