Я использую большой набор данных для отслеживания моей модели lstm. Я хочу, чтобы весь процесс обучения непрерывно сохранял промежуточные веса и мог в любой момент остановиться и перезагрузить предварительно обученные веса, чтобы продолжить обучение на другом наборе данных. .
Я использовал trid fit_generator, но он мне не подходит.
Например:
Теперь у меня есть 10К данных, я разделил их на 10 маленьких данных, у каждого маленьких данных есть 1К.
В цикле for я читаю одну из небольших данных, тренируюсь на ней и сохраняю вес. В следующий раз я загружаю предварительно обученные веса, читаю еще одну небольшую информацию и тренирую ее. , .
#some fake code here
for pos, file_name in np.arange(0, 11000, 1000):
data = read(file_name)
modelpath = . . .
checkpoint = ModelCheckpoint(modelpath, period=15, . . .)
callbacks_list = [checkpoint]
initial_epoch=0
file_list = os.listdir(main_path + 'modles/')
if len(file_list) > 0:
epoch_list = get_file_list(main_path + 'modles/')
epoch_last = epoch_list[-1]
model.load_weights(main_path + 'modles/' + epoch_last)
print("checkpoint_loaded: ", epoch_last)
if epoch_last.split('-')[2] == '015' and epoch_last.split('-')[1] == file_name:
initial_epoch = 15
if epoch_last.split('-')[2] == '030' and epoch_last.split('-')[1] == file_name:
initial_epoch = 30
print('Begin from epoch: ', str(initial_epoch))
model.fit([data],
epochs=30,
batch_size=10,
validation_split=0.1,
callbacks=callbacks_list,
initial_epoch=initial_epoch
)
Я не знаю, работает ли приведенный выше код.
В орудии я вижу, что потери начинаются примерно с 2,5 и примерно до 1 из 30 эпох каждый раз, когда я читаю один из файлов.