Я занимаюсь машинным обучением, и мне приходится иметь дело с пользовательской функцией потерь.Производные и гессиан функции потерь сложно вывести, поэтому я прибегнул к их автоматическому вычислению с использованием Tensorflow.
Вот пример.
import numpy as np
import tensorflow as tf
y_true = np.array([
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]
], dtype=float)
y_pred = np.array([
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]
], dtype=float)
weights = np.array([1, 1, 1, 1, 1], dtype=float)
with tf.Session():
# We first convert the numpy arrays to Tensorflow tensors
y_true = tf.convert_to_tensor(y_true)
y_pred = tf.convert_to_tensor(y_pred)
weights = tf.convert_to_tensor(weights)
# The following code block is a custom loss
ys = tf.reduce_sum(y_true, axis=0)
y_true = y_true / ys
ln_p = tf.nn.log_softmax(y_pred)
wll = tf.reduce_sum(y_true * ln_p, axis=0)
loss = -tf.tensordot(weights, wll, axes=1)
grad = tf.gradients(loss, y_pred)[0]
hess = tf.hessians(loss, y_pred)[0]
hess = tf.diag_part(hess)
print(hess.eval())
, который печатает
[[0.24090069 0.12669198 0.12669198 0.12669198 0.12669198]
[0.12669198 0.24090069 0.12669198 0.12669198 0.12669198]
[0.12669198 0.12669198 0.12669198 0.24090069 0.12669198]
[0.12669198 0.12669198 0.24090069 0.12669198 0.12669198]
[0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
[0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
[0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]]
Я доволен этим, потому что он работает, проблема в том, что он не масштабируется.Для моего случая использования мне нужна только диагональ матрицы Гессе.Мне удалось извлечь его, используя hess = tf.diag_part(hess)
, но он все равно вычислит полный гессиан, что не нужно.Накладные расходы настолько плохи, что я не могу использовать их для наборов данных среднего размера (~ 100 тыс. Строк).
Мой вопрос таков: Есть ли лучший способ извлечь диагональ гессиана ?Мне хорошо известно об этом сообщении и этом , но я не нахожу достаточно хороших ответов.