Причиной переполнения Python является произведение многих вероятностей, обращения большой ковариационной матрицы и вычисления ее определителя. - PullRequest
0 голосов
/ 10 ноября 2019

Итак, у меня есть такая функция:

def logpp(X,m,S):
# Find the number of dimensions from the data vector
d = X.shape[1]

# Invert the covariance matrix
Sinv = np.linalg.inv(S)

# Compute the quadratic terms for all data points
Q = -0.5*(np.dot(X-m,Sinv)*(X-m)).sum(axis=1)

# Raise them quadratic terms to the exponential
Q = np.exp(Q)

# Divide by the terms in the denominator
P = Q / np.sqrt((2*np.pi)**d * np.linalg.det(S))

# Take the product of the probability of each data points
Pprod = np.prod(P)

# Return the log-probability
return np.log(Pprod)

Когда я генерирую больший ввод, результат будет переполнен. Как переписать порядок, чтобы избежать переполнения?

Моя функция ввода:

X1 = numpy.random.mtrand.RandomState(123).normal(0,1,[5,len(m1)])
X2 = numpy.random.mtrand.RandomState(123).normal(0,1,[20,len(m2)])
X3 = numpy.random.mtrand.RandomState(123).normal(0,1,[100,len(m3)])

Ответы [ 2 ]

0 голосов
/ 11 ноября 2019

Без использования функции «slove»:

def logp_robust(X,m,S):

# Find the number of dimensions from the data vector
d = X.shape[1]
N = X.shape[0]
# Invert the covariance matrix
Sinv = numpy.linalg.inv(S)

# Compute the quadratic terms for all data points
Q = numpy.sum(-0.5*(numpy.dot(X-m,Sinv)*(X-m)).sum(axis=1))


return (Q-0.5 * (d * N)* numpy.log(2*numpy.pi)-0.5* N* numpy.log(numpy.linalg.det(S)))
0 голосов
/ 10 ноября 2019

пара указателей:

  • обратите внимание, что при работе с вероятностями вы обычно хотите оставаться в «пространстве журнала», поскольку все неотрицательно и код имеет тенденцию использовать значения, которые находятся ниже / переполнены плавающимчисла точек легко
  • библиотеки линейной алгебры имеют множество инструментов для работы с плохо обусловленными и другими числовыми нестабильностями, ознакомьтесь с учебником linalg для нескольких указателей

в вашем случае я бы переписал эту функцию так:

import numpy as np
from scipy import linalg

def logpp(X, m, S):
    X = X - m
    Q = -0.5 * (linalg.solve(S, X.T).T * X).sum(axis=1)

    sgn, logdetS = np.linalg.slogdet(S)
    assert sgn > 0

    logP = Q - 0.5 * (np.log(2*np.pi)*d + logdetS)

    return np.sum(logP)

, что лучше для меня

Я не совсем уверен в правильности использования solve здесь, но я думаю, что это правильно. если бы кто-то мог прокомментировать это, было бы здорово!

...