Диагональ гессиана с тензорным потоком - PullRequest
0 голосов
/ 19 ноября 2018

Я занимаюсь машинным обучением, и мне приходится иметь дело с пользовательской функцией потерь.Производные и гессиан функции потерь сложно вывести, поэтому я прибегнул к их автоматическому вычислению с использованием Tensorflow.

Вот пример.

import numpy as np
import tensorflow as tf

y_true = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1]
], dtype=float)

y_pred = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1]
], dtype=float)

weights = np.array([1, 1, 1, 1, 1], dtype=float)

with tf.Session():

    # We first convert the numpy arrays to Tensorflow tensors
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    weights = tf.convert_to_tensor(weights)

    # The following code block is a custom loss 
    ys = tf.reduce_sum(y_true, axis=0)
    y_true = y_true / ys
    ln_p = tf.nn.log_softmax(y_pred)
    wll = tf.reduce_sum(y_true * ln_p, axis=0)
    loss = -tf.tensordot(weights, wll, axes=1)

    grad = tf.gradients(loss, y_pred)[0]

    hess = tf.hessians(loss, y_pred)[0]
    hess = tf.diag_part(hess)

    print(hess.eval())

, который печатает

[[0.24090069 0.12669198 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.24090069 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.12669198 0.12669198 0.24090069 0.12669198]
 [0.12669198 0.12669198 0.24090069 0.12669198 0.12669198]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]]

Я доволен этим, потому что он работает, проблема в том, что он не масштабируется.Для моего случая использования мне нужна только диагональ матрицы Гессе.Мне удалось извлечь его, используя hess = tf.diag_part(hess), но он все равно вычислит полный гессиан, что не нужно.Накладные расходы настолько плохи, что я не могу использовать их для наборов данных среднего размера (~ 100 тыс. Строк).

Мой вопрос таков: Есть ли лучший способ извлечь диагональ гессиана ?Мне хорошо известно об этом сообщении и этом , но я не нахожу достаточно хороших ответов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...