Пользовательская функция потерь в кератах, включающая огромное матричное умножение - PullRequest
0 голосов
/ 22 октября 2019

У меня проблемы с написанием пользовательской функции потерь в Керасе. У меня есть веса слоя "W" и матрица "M". Я хочу выполнить следующую трассировку операции ((W * M) * W '), чтобы вычислить мою функцию потерь. Трассировка является суммой диагональных элементов. В numpy я бы сделал следующее:

np.trace(np.dot(np.dot(W,M),W.T))) or 

def custom_regularizer(W,M):
    sum_reg = 0
    for i in range(W.shape[1]):
        for j in range(i,W.shape[1]):
            vector = W[:,i] - W[:,j]
            sum_reg = sum_reg + M[i,j] * (LA.norm(vector)**2)
    return sum_reg

Для керас я написал следующую функцию потерь

def custom_loss(W):

  def lossFunction(y_true,y_pred):    
    loss = tf.trace(K.dot(K.dot(W,K.constant(M)),K.transpose(W)))
    return loss

return lossFunction

Проблема в том, что керас вычисляет всю внешнюю матрицу, размерность которой200000 * 200000, что дает ошибку памяти. Есть ли способ, с помощью которого я могу просто получить сумму диагональных элементов, не делая вычисления всей матрицы.

Как сделать так же, как функцию потери кераса?

1 Ответ

1 голос
/ 22 октября 2019

Если вы выполните некоторые хитрые уловки, чтобы вычислить трассировку, у вас не должно быть нехватки памяти. Например, вы можете сослаться на this .

...