KERAS: пользовательская функция потери от логики - PullRequest
0 голосов
/ 01 марта 2019

Я пытаюсь реализовать пользовательскую функцию потерь согласно этой бумаге , уравнение 3.8 стр. 19. Я пришел к этой реализации:

import numpy as np
import keras.backend as T
from keras import models

def costume_loss(censored):
    '''costume loss function'''

    def loglikelihood_function(y_true,y_pred,censored):
        '''implements likelihood function as per equation 3.8 in using survival prediction techniques to learn consumer'''
        results=[]
        n=T.int_shape(y_pred)[0]  #number of instances
        K=T.int_shape(y_pred)[1]    #number of subintervals

        print(T.int_shape(y_pred)[0],T.int_shape(y_pred)[1])


        for i in range(n):
            if censored[i]==0:
                sum1=np.sum([y_pred[i][j]*y_true[i][j] for j in range(K)])
                sum2=np.sum([math.exp(np.sum([ y_pred[i][k]  for k in range(j,K)])) for j in range(K)])
                results.append(-(sum1-math.log(sum2)))      
            else:
                sum1=np.sum([y_true[i][j]*math.exp(np.sum([y_pred[i][k] for k in range(j,K)]))   for j in range(K)])
                sum2=np.sum([math.exp(np.sum([ y_pred[i][k]  for k in range(j,K)])) for j in range(K)])
            results.append(-(math.log(sum1)-math.log(sum2))) 

        x=tf.constant(np.array(results,dtype='float64'))   #convert into tensor
        return T.mean(x)

    def loss_function(y_true,y_pred):
        return loglikelihood_function(y_true,y_pred,censored)

    return loss_function

Однако, когда я пытаюсь скомпилироватьмодель:

model.compile(optimizer='rmsprop',loss=costume_loss(censored=c),metrics=['accuracy'])

Я получаю

TypeError: 'NoneType' object cannot be interpreted as an integer

Похоже, что размер партии не определен во время расчета.Может ли кто-нибудь указать мне правильное направление?Может быть, мне нужно реализовать это с помощью тензорных операций?Если да, то как?

Спасибо

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...