Есть ли способ передать временные веса в функцию потерь? - PullRequest
0 голосов
/ 15 апреля 2020

Фон

В настоящее время я использую LSTM для выполнения регрессии. Я использую пакеты небольшого размера с достаточно большим количеством временных шагов (но намного, намного меньше, чем количество моих шагов).

Я пытаюсь перейти на большие партии с меньшим количеством временных шагов, но с с включенным состоянием, что позволяет использовать больший объем сгенерированных обучающих данных.

Однако в настоящее время я использую регуляризацию, основанную на sqrt (временном шаге), (это проверено на абляцию и помогает в скорости сходимости, это работает из-за статистической природы проблемы, ожидаемая ошибка уменьшается в 2 раза (временной шаг). Это выполняется с помощью tf.range для генерации списка правильного размера в функции потерь. Этот подход не будет правильным, если включен режим с сохранением состояния, поскольку он будет считать неправильное количество временных шагов (количество временных шагов в этом пакете, а не замечено до сих пор в целом).

Вопрос

Есть ли способ передать смещение или список целых чисел или чисел в функцию потерь? Желательно без изменения модели, но я признаю, что взлом такого рода может потребоваться.

Код

Упрощенная модель:

def create_model():    
    inputs = Input(shape=(None,input_nodes))
    next_input = inputs
    for i in range(dense_layers):
        dense = TimeDistributed(Dense(units=dense_nodes,
                activation='relu',
                kernel_regularizer=l2(regularization_weight),
                activity_regularizer=l2(regularization_weight)))\
            (next_input)
        next_input = TimeDistributed(Dropout(dropout_dense))(dense)

    for i in range(lstm_layers):
        prev_input = next_input
        next_input = LSTM(units=lstm_nodes,
                dropout=dropout_lstm,
                recurrent_dropout=dropout_lstm,
                kernel_regularizer=l2(regularization_weight),
                recurrent_regularizer=l2(regularization_weight),
                activity_regularizer=l2(regularization_weight),
                stateful=True,
                return_sequences=True)\
            (prev_input)
        next_input = add([prev_input, next_input])

    outputs = TimeDistributed(Dense(output_nodes,
            kernel_regularizer=l2(regularization_weight),
            activity_regularizer=l2(regularization_weight)))\
        (next_input)

    model = Model(inputs=inputs, outputs=outputs)

Функция потери

def loss_function(y_true, y_pred):
    length = K.shape(y_pred)[1]

    seq = K.ones(shape=(length,))
    if use_sqrt_loss_scaling:
        seq = tf.range(1, length+1, dtype='int32')
        seq = K.sqrt(tf.cast(seq, tf.float32))

    seq = K.reshape(seq, (-1, 1))

    if separate_theta_phi:
        angle_loss = phi_loss_weight * phi_metric(y_true, y_pred, angle_loss_fun)
        angle_loss += theta_loss_weight * theta_metric(y_true, y_pred, angle_loss_fun)
    else:
        angle_loss = angle_loss_weight * total_angle_metric(y_true, y_pred, angle_loss_fun)

    norm_loss = norm_loss_weight * norm_loss_fun(y_true, y_pred)
    energy_loss = energy_loss_weight * energy_metric(y_true, y_pred)
    stability_loss = stability_loss_weight * stab_loss_fun(y_true, y_pred)
    act_loss = act_loss_weight * act_loss_fun(y_true, y_pred)

    return K.sum(K.dot(0
        + angle_loss
        + norm_loss
        + energy_loss
        + stability_loss
        + act_loss
        , seq))

(Функции, которые вычисляют части функции потерь, не должны быть очень важными. Просто они также являются функциями потерь.)

1 Ответ

1 голос
/ 15 апреля 2020

Для этой цели вы можете использовать sample_weight аргумент метода fit и передать от sample_weight_mode='temporal' до compile метода, чтобы можно было присвоить вес каждому временному шагу каждого образца в пакете:

model.compile(..., sample_weight_mode='temporal')
model.fit(..., sample_weight=sample_weight)

sample_weight должен быть массивом формы (num_samples, num_timesteps).

Обратите внимание, что если вы используете генератор входных данных или экземпляр Sequence, вместо этого вам нужно передать образец весит как третий элемент сгенерированного кортежа / списка в генераторе или Sequence instance.

...