Я работаю над проектом, который требует некоторой работы в сети LSTM.У меня не было таких проблем раньше, и странным образом я ничего не изменил в этой части кода.
Проблема в том, что у меня есть обратный вызов, чтобы записать процесс обучения модели в файл, который называется Logger
.В методе on_train_end
я вызываю другую пользовательскую функцию для сохранения графиков loss
, acc
и perplexity
.Но параметр logs
метода on_train_end
задается пустым словарем, с этой стороны проблем с on_epoch_end
.
def on_train_end(self, logs={}):
#calculate total time
self.train_dur = time.time()-self.train_start
self.log.write("\n\nTotal duration: " + str(self.train_dur) + " seconds\n")
self.log.write("*"*100+"\n")
self.log.close()
print("train end logs:", logs)
self.__save_plots(logs)
#write time to a file
return
def on_epoch_end(self, epoch, logs={}):
#calculate epoch time
self.epoch_dur = time.time()-self.epoch_start
#write epoch logs to a file
print("epoch end logs:" , logs)
epoch_loss_info = "\nloss: {loss} -- val_loss: {val_loss}".format(loss = logs["loss"], val_loss = logs["val_loss"])
epoch_acc_info = "\nacc: {acc} -- val_acc: {val_acc}".format(acc = logs["acc"], val_acc = logs["val_acc"])
epoch_ppl_info = "\nppl: {ppl} -- val_ppl: {val_ppl}\n".format(ppl=logs["ppl"], val_ppl=logs["val_ppl"])
self.log.write("-"*100+"\n")
self.log.write("\n\nEpoch: {epoch} took {dur} seconds \n".format(epoch=epoch+1, dur=self.epoch_dur))
self.log.write(epoch_loss_info+epoch_acc_info+epoch_ppl_info)
#write generations to a file
generator = model_generator(gen_seq_len=self.gen_seq_len, by_dir=self.model_dir)
generated_text = generator.generate()
self.log.write("\nInput text:\t" + generated_text[:self.gen_seq_len] + "\n" )
self.log.write("\nGenerated text:\t" + generated_text + "\n")
self.log.write("-"*100+"\n")
return
нет. Как вы можете видеть ниже, у меня есть print
функция в каждом методе и print("epoch end logs")
выводит dict
, заполненный правильными значениями.Однако print("train end logs")
печатает пустую dict.
Я также пытался получить history
в качестве функции возврата из fit_generator
и пытался распечатать ее.Это также идет с ценностями.
Я искал GitHub
и Stackoverflow
, но не видел ничего подобного.
Заранее спасибо.