Как замаскировать мульти-выход в обучении Tensorflow 2 LSTM? - PullRequest
1 голос
/ 17 марта 2020

Я тренирую модель LSTM в Tensorflow 2, чтобы предсказать два выхода: поток и температуру воды.

  • Для некоторых временных шагов есть метка потока и метка температуры,
  • Для некоторых есть только метка потока или метка температуры,
  • , а для некоторых есть нет .

Таким образом, функция потерь должна игнорировать температуру и потерю потока, когда они не имеют метки. Я довольно много читал в документах по TF, но я изо всех сил пытаюсь понять, как лучше всего это сделать.

Пока я пытался

  • , указав sample_weight_mode='temporal' при компиляции модели и включении массива sample_weight numpy при вызове fit

Когда я это делаю, я получаю сообщение об ошибке с просьбой передать 2D-массив. Но это смущает меня, потому что есть 3 измерения: n_samples, sequence_length и n_outputs.

Вот некоторый код того, что я в основном пытаюсь сделать:

import tensorflow as tf
import numpy as np

# set up the model
simple_lstm_model = tf.keras.models.Sequential([
    tf.keras.layers.LSTM(8, return_sequences=True),
    tf.keras.layers.Dense(2)
])

simple_lstm_model.compile(optimizer='adam', loss='mae',
                          sample_weight_mode='temporal')

n_sample = 2
seq_len = 10
n_feat = 5
n_out = 2

# random in/out
x = np.random.randn(n_sample, seq_len, n_feat)
y_true = np.random.randn(n_sample, seq_len, n_out)

# set the initial mask as all ones (everything counts equally)
mask = np.ones([n_sample, seq_len, n_out])
# set the mask so that in the 0th sample, in the 3-8th time step
# the 1th variable is not counted in the loss function
mask[0, 3:8, 1] = 0

simple_lstm_model.fit(x, y_true, sample_weight=mask)

Ошибка:

ValueError: Found a sample_weight array with shape (2, 10, 2). In order to use timestep-wise sample weighting, you should
pass a 2D sample_weight array.

Есть идеи? Я не должен понимать, что делает sample_weights, потому что для меня это имеет смысл, только если массив sample_weight имеет те же размеры, что и выходные данные. Я мог бы написать собственную функцию потерь и обработать маскирование вручную, но, похоже, должно быть более общее или встроенное решение.

1 Ответ

2 голосов
/ 17 марта 2020

1. sample_weights

Да, вы понимаете это неправильно. В этом случае у вас есть 2 выборки, 10 временных шагов с 5 функциями каждый. Вы можете передать 2D тензор таким образом, чтобы каждый временной шаг для каждого образца по-разному вносил вклад в общую потерю, все функции были одинаково взвешены (как это обычно бывает).

Это не то, что вам нужно на все . Вы хотите замаскировать определенные значения потерь после их расчета, чтобы они не вносили вклад.

2. Пользовательские потери

Одним из возможных решений является реализация собственной функции потерь, которая умножает тензор потерь на маску перед тем, как взять mean или sum.

По сути, вы передаете mask и * 1021. * как-то соединяются вместе и разделяют его внутри функции для использования. Этого достаточно:

def my_loss_function(y_true_mask, y_pred):
    # Recover y and mask
    y_true, mask = tf.split(y_true_mask, 2)
    # You could user reduce_sum or other combinations
    return tf.math.reduce_mean(tf.math.abs(y_true - y_pred) * mask)

Теперь ваш код (без взвешивания, так как он не нужен):

simple_lstm_model = tf.keras.models.Sequential(
    [tf.keras.layers.LSTM(8, return_sequences=True), tf.keras.layers.Dense(2)]
)

simple_lstm_model.compile(optimizer="adam", loss=my_loss_function)

n_sample = 2
seq_len = 10
n_feat = 5
n_out = 2

x = np.random.randn(n_sample, seq_len, n_feat)
y_true = np.random.randn(n_sample, seq_len, n_out)

mask = np.ones([n_sample, seq_len, n_out])
mask[0, 3:8, 1] = 0

# Stack y and mask together
y_true_mask = np.stack([y_true, mask])

simple_lstm_model.fit(x, y_true_mask)

И так все работает. Вы также можете сложить значения другим способом, но я надеюсь, что вы почувствуете, как это можно сделать.

3. Маскировка выходов

Обратите внимание, что выше возникает несколько проблем. Если у вас много нулей и вы взяли mean, вы можете получить очень маленькое значение потери и помешать обучению. С другой стороны, если вы go с sum, он может взорваться.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...