Использовать пользовательскую функцию с пользовательскими параметрами в обратном вызове keras - PullRequest
0 голосов
/ 09 июля 2019

Я тренирую модель в Керасе и хочу строить графики результатов после каждой эпохи. Я знаю, что обратные вызовы keras предоставляют функцию «on_epoch_end», которая может быть перегружена, если кто-то хочет выполнить некоторые вычисления после каждой эпохи, но моя функция принимает некоторые дополнительные параметры, которые при выдаче сбрасывают код из-за ошибки мета-класса. Подробности приведены ниже:

Вот как я делаю это прямо сейчас, и это прекрасно работает: -

class NewCallback(Callback):

def on_epoch_end(self, epoch, logs={}):  #working fine, printing epoch after each epoch
    print("EPOCH IS: "+str(epoch))


epochs=5
batch_size = 16
model_saved=False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
           callbacks=[NewCallback()])

Но я хочу, чтобы моя функция обратного вызова была такой: -

class NewCallback(Callback,models,data,batch_size):
   def on_epoch_end(self, epoch, logs={}):
     print("EPOCH IS: "+str(epoch))
     x=models.predict(data)
     plt.plot(x)
     plt.savefig(epoch+".png")

Если я назову это так в форме:

callbacks=[NewCallback(models, data, batch_size=batch_size)]

Я получаю эту ошибку:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases 

Я ищу более простое решение для вызова моей функции или устранения этой ошибки в метаклассе, любая помощь будет высоко оценена!

1 Ответ

2 голосов
/ 09 июля 2019

Я думаю, что вы хотели бы определить класс, который происходит от обратного вызова и принимает модели, данные и т. Д. В качестве аргументов конструктора. Итак:

class NewCallback(Callback):
    """ NewCallback descends from Callback
    """
    def __init__(self, models, data, batch_size):
        """ Save params in constructor
        """
        self.models = models

    def on_epoch_end(self, epoch, logs={}):
        x = self.models.predict(self.data)
...