Как использовать Keras `add_loss` исключительно - PullRequest
2 голосов
/ 26 мая 2019

Я реализую модель со сложным термином потерь, для которого требуется функция Keras add_loss. Я хочу реализовать свою модель как tf.keras.Model (см. документация ). К сожалению, кажется, что Keras не может справиться со случаем, когда используется add_loss, но compile не получает спецификацию функции потерь.

Вот минимальный пример, который «работает» (ничего интересного не делает, но ошибок не возникает):

import tensorflow as tf
import numpy as np    

class Model(tf.keras.Model):

    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs)    
        self.layer = tf.keras.layers.Dense(3, 
                       activation=tf.keras.activations.relu)

    def call(self, inputs, **kwargs):    
        output = self.layer(inputs)
        return output


if __name__ == '__main__':
    tf.random.set_random_seed(1)
    m = Model()    
    x = np.array([[1., 2., 3.]])
    y = np.array([[0., 1., 0.5]])    
    m.compile(tf.keras.optimizers.SGD(), 
              loss=tf.keras.losses.mean_squared_error)
    m.fit(x, y, epochs=100, verbose=0)
    print(m.predict(x))
    # [[0.         0.99999666 0.5000101 ]]

Однако мне не нужна «внешняя» функция потерь, только одна, указанная с помощью add_loss, и это не работает. Я надеялся, что следующий код делает то же самое, что и приведенный выше:

# Same imports...

class Model(tf.keras.Model):

    # def __init__ ...

    def call(self, inputs, **kwargs):    
        output = self.layer(inputs)
        error = tf.keras.losses.mean_squared_error([0., 1., 0.5], output)
        self.add_loss(error)
        return output


if __name__ == '__main__':    
    tf.random.set_random_seed(1)
    m = Model()    
    x = np.array([[1., 2., 3.]])    
    m.compile(tf.keras.optimizers.SGD())
    m.fit(x, epochs=100, verbose=0)
    print(m.predict(x))

но это не работает. В частности, с m.fit(x) я получаю сообщение об ошибке

AttributeError: 'Model' object has no attribute 'total_loss'

Один из возможных обходных путей - добавить функцию тривиальных потерь

def no_loss(*args):
    return tf.keras.backend.constant(0.0)

но я надеялся на более элегантное решение.

Я использую Tensorflow 1.13.1.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...