Пользовательская функция потерь Tensorflow 2.x в Google Colaboratory - PullRequest
0 голосов
/ 06 августа 2020

Проблема

Я пытаюсь написать собственную функцию потерь для моей модели Tensorflow 2. Я написал следующую функцию, которая вычисляет потерю, которую я ищу, когда я вручную передаю тензор ввода и вывода. , где 0 встречаются гораздо чаще. Таким образом, моя модель получала хорошее значение потерь, просто получая большую часть правильных 0, хотя там, где находятся единицы, является более важным metri c. Эта настраиваемая потеря делает более пропорциональный акцент на расположении единиц.

Я изменил model.compile(loss="binary_crossentropy") на model.compile(loss=on_off_balance_loss) в попытке использовать новую функцию потерь. Похоже, это не работает, поскольку функция потерь должна принимать весь пакет данных . Итак, я пробовал что-то вроде этого с model.compile(loss=on_off_balance_batch_loss):

def on_off_balance_batch_loss(y_true, y_pred) -> float:
    y_trues: list = tf.unstack(y_true)
    y_preds: list = tf.unstack(y_pred)
    loss: float = 0

    for i in range(0, len(y_trues)):
        loss = loss * (i / (i + 1)) + (on_off_balance_loss(y_trues[i], y_preds[i]) / (i + 1))

    return loss

Это не работает. Форма y_true - (None, None, None), форма y_pred - (None, X, Y), где X и Y - это размеры двумерного массива единиц и нулей.

I я работаю в Google Colaboratory. Однако локально np.asarray(), похоже, работает способом, который вызывает ошибку в Colaboratory. Итак, я не совсем уверен, кроется ли ошибка в моей функции потерь или в какой-то настройке в Colaboratory. Я убедился, что использую Tensorflow 2.3.0 как локально, так и в Colaboratory.

РЕДАКТИРОВАНИЯ:

Я пробовал добавить run_eagerly=True к model.compile() и использовать .numpy() вместо np.asarray() в on_off_balance_loss(). Это изменило тип ввода в on_off_balance_batch_loss с Tensor на EagerTensor. Это приводит к ошибке ValueError: No gradients provided for any variable: ['lstm_3/lstm_cell_3/kernel:0', 'lstm_3/lstm_cell_3/recurrent_kernel:0', 'lstm_3/lstm_cell_3/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0', 'lstm_4/lstm_cell_4/kernel:0', 'lstm_4/lstm_cell_4/recurrent_kernel:0', 'lstm_4/lstm_cell_4/bias:0', 'dense_3/kernel:0', 'dense_3/bias:0', 'lstm_5/lstm_cell_5/kernel:0', 'lstm_5/lstm_cell_5/recurrent_kernel:0', 'lstm_5/lstm_cell_5/bias:0'].. Та же ошибка возникает, если я использую

def on_off_balance_batch_loss(y_true: EagerTensor, y_pred: EagerTensor) -> float:
    y_trues = tf.TensorArray(tf.float32, 1, dynamic_size=True, infer_shape=False).unstack(y_true)
    y_preds = tf.TensorArray(tf.float32, 1, dynamic_size=True, infer_shape=False).unstack(y_pred)

    loss: float = 0.0
    i: int = 0

    for tensor in range(y_trues.size()):
        elem_loss: float = on_off_balance_loss(y_trues.read(i), y_preds.read(i))
        loss = loss * (i / (i + 1)) + (elem_loss / (i + 1))
        i += 1

    return loss

и опускаю run_eagerly=True. Даже до того, как ошибки достигнуты, кажется, что вся программа работает медленнее, чем когда я использовал функцию потерь по умолчанию.

1 Ответ

0 голосов
/ 10 августа 2020

Что ж, оказывается, решение намного проще, чем то, что я пробовал выше. Ниже представлен стиль, в котором должна быть реализована такая функция. Я сравнил его с результатом моей исходной функции on_off_balance_loss, и он соответствует.

def on_off_equal_loss(y_true: Tensor, y_pred: Tensor) -> Tensor:
    on_delta: float = 0.99

    on_mask: Tensor = tf.greater_equal(y_true, on_delta)
    off_mask: Tensor = tf.less(y_true, on_delta)

    on_loss: Tensor = tf.divide(tf.reduce_sum(tf.abs(tf.subtract(
        y_true[on_mask], y_pred[on_mask]
    ))), tf.cast(tf.math.count_nonzero(on_mask), tf.float32))

    off_loss: Tensor = tf.divide(tf.reduce_sum(tf.abs(tf.subtract(
        y_true[off_mask], y_pred[off_mask]
    ))), tf.cast(tf.math.count_nonzero(off_mask), tf.float32))

    on_factor: float = 4.0
    return tf.divide(tf.add(tf.multiply(on_factor, on_loss), off_loss), on_factor + 1.0)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...