Как использовать tenorflow.distributions в пользовательской функции потерь для модели keras - PullRequest
0 голосов
/ 04 мая 2019

Для модели глубокого обучения, которую я определил с помощью tf2.0 keras, мне нужно написать собственную функцию потерь.

Поскольку это будет зависеть от таких вещей, как энтропия и нормальный log_prob, моя жизнь станет менее болезненной, если я смогу использовать tf.distributions.Normal и использовать две модели outpus соответственно mu и sigma.

Однако, как только я помещаю это в свою функцию потерь, я получаю ошибку Keras, что градиент для этой функции не определен.

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

Я пытался инкапсулировать вызов в tf.contrib.eager.Variable, когда я где-то читал. Не помогло.

Какой трюк их использовать? Я не вижу причины из фундаментальной архитектуры, почему я не мог бы использовать их в смешанной форме.

#this is just an example which does not really give a meaningful result.
import tensorflow as tf
import tensorflow.keras as K
import numpy as np

def custom_loss_fkt(extra_output):
    def loss(y_true,y_pred):
        dist = tf.distributions.Normal(loc=y_pred,scale=extra_output)
        d = dist.entropy()
        return K.backend.mean(d)
    return loss

input_node = K.layers.Input(shape=(1,))
dense = K.layers.Dense(8,activation='relu')(input_node)
#dense = K.layers.Dense(4,activation='relu')(dense)
out1 = K.layers.Dense(4,activation='linear')(dense)
out2 = K.layers.Dense(4,activation ='linear')(dense)
model = K.Model(inputs = input_node, outputs = [out1,out2])
model.compile(optimizer = 'adam', loss = [custom_loss_fkt(out2),custom_loss_fkt(out1)])
model.summary()
x = np.zeros((1,1))
y1 = np.array([[0.,0.1,0.2,0.3]])
y2 = np.array([[0.1,0.1,0.1,0.1]])
model.fit(x,[y1,y2],epochs=1000,verbose=0)
print(model.predict(x))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...