Пользовательский обратный вызов keras для обновления градиентов каждые n партий - PullRequest
0 голосов
/ 28 января 2020

Вот код для пользовательского обратного вызова, но он не работает:

class subdivision(callbacks.Callback):
    def __init__(self,subdiv):
        super(subdivision,self).__init__()
        self.subdiv = subdiv
        self.weights = None
        self.index = 1
        self.gradient = None

    def on_train_begin(self,logs=None):
        self.weights = self.model.get_weights()

    def on_batch_begin(self,batch,logs=None):
        self.model.set_weights(self.weights)

    def on_batch_end(self,batch,logs=None):
        if (self.index % self.subdiv==0):
            self.model.set_weights(self.weights+self.gradient)
            self.weights = self.model.get_weights()
            self.gradient=None
        else:
            if self.gradient==None:
                self.gradient = k.gradients(self.model.output,self.model.weights)
            else:
                self.gradient += k.gradients(self.model.output,self.model.weights)
        self.index+=1
    def on_epoch_end(self,epoch,logs=None):
        self.index = 1 

Я не могу добавить вес модели с градиентами. Я также не могу прочитать значения из тензора градиента.

...