keras возобновляет тренировку fit_generator (модель LSTM), теряя значение последней потери - PullRequest
0 голосов
/ 09 мая 2019

Может ли кто-нибудь помочь мне найти способ возобновить обучение модели LSTM с помощью fit_generator, не возвращая значение потерь в inf?

Справочная информация: Я обучаю модель LSTM, и у меня есть очень большие данные временного ряда (много раз выборки) и только 2 функции.Поэтому форма данных моего временного ряда x равна N 2, где N - очень большое число.Я использую генератор пакетов, чтобы случайным образом сегментировать мои данные в меньшие партии batch_N на 2 (где batch_N намного меньше, чем N):

def batch_generator(batch_size, sequence_length): 
...
    for i in range(batch_size):
        ...
       x_batch[i] = batch_x_train_scaled[idx:idx+sequence_length]
       y_batch[i] = batch_y_train_scaled[idx:idx+sequence_length] 
 yield (x_batch, y_batch)

Я также использую ModelCheckpoint для сохранения наиболее обученной модели

callback_checkpoint = ModelCheckpoint(filepath=path_checkpoint, 
monitor='val_loss', verbose=1,save_weights_only=False, save_best_only=True)

Кроме того, каждый раз, когда я хочу возобновить обучение, сначала загружаю последнюю сохраненную модель:

if True:
    try:
#         model.load_weights(path_checkpoint)
        model = load_model(path_checkpoint)

    except Exception as error:
        print("Error trying to load checkpoint.")
        print(error)

В чем проблема? Каждый раз, когда я возобновляюПри обучении новый пакетный файл сначала загружается в fit_generator, и модель будет использовать последние веса сохранения.значение потерь сбрасывается на инф.Следовательно, в конце первой возобновленной эпохи обучения, независимо от того, насколько хороши или плохи результаты обучения, модель сообщает, что val_loss улучшен с inf до некоторого числа и, таким образом, перезаписывает новые весовые коэффициенты.Проблема в том, что иногда новые веса не так оптимальны, как предыдущие (из-за того, что на этот раз модель использует новые данные партии для обучения), и поэтому я потеряю некоторые оптимальные веса.

Что я уже сделал, чтобы решить эту проблему?

Первый подход (неудачный) : определение пользовательской функции потерь:

def my_loss(y_true, y_pred):
    train_loss = binary_crossentropy(y_true, y_pred)
    validation_loss = 2*binary_crossentropy(y_true, y_pred)
    temp=tf.keras.backend.cast(validation_loss,'float16')
    if temp>1:  # update 1 to last best val_loss before resume training
        validation_loss=validation_loss+np.inf
# validation_loss=np.inf
    return tf.keras.backend.in_train_phase(train_loss, validation_loss)

model.compile(loss=my_loss, optimizer=optimizer)

результаты первого подхода:

Error:
---> 12     if temp>1:
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

Второй подход (неудачный) : определение пользовательского обратного вызова для сохранения модели:

best_val_loss = 1 # update 1 to last best val_loss before resume training

def saveModel(epoch,logs):
    val_loss = logs['val_loss']
    if val_loss < best_val_loss:
        best_val_loss=val_loss
        model.save('my_model.hdf5')

my_callback = LambdaCallback(on_epoch_end=saveModel)

результаты второго подхода:

UnboundLocalError: local variable 'best_val_loss' referenced before assignment

третий подход (неудачный) : определение пользовательского обратного вызова для сохранения модели:

best_val_loss = 1 # update 1 to last best val_loss before resume training

def saveModel(epoch,logs,best_val_loss):
    val_loss = logs['val_loss']
    if val_loss < best_val_loss:
        best_val_loss=val_loss
        model.save('my_model.hdf5')

my_callback = LambdaCallback(on_epoch_end=saveModel)

результатыподхода три:

TypeError: saveModel() missing 1 required positional argument: 'best_val_loss'
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...