Как работать со значениями nan при использовании multinomial.pmf из scipy.stats? - PullRequest
1 голос
/ 15 апреля 2020

Запустив небольшой эксперимент, я заметил, что если сумма параметров, передаваемых в multinomial.pmf, даже немного выше 1, то возвращаемое значение равно nan.

См. Пример ниже:

import numpy as np
from scipy.stats import multinomial as multi_s


def safe_multi(x, params):
    params_sum = params.sum()
    safe_params = params / params_sum if params_sum > 1 else params
    return multi_s.pmf(x, sum(x), safe_params)


params1 = np.array(
    [0.21310660657549002, 0.21310660657549002, 0.21310660657549002,
     2.8699968847179538e-06, 0.0023286820110742764, 2.8699968847179538e-06,
     0.0023286820110742764, 2.8699968847179538e-06, 2.8699968847179538e-06,
     0.0023258120141895593, 0.0016006555371205703, 0.0023258120141895593,
     0.0016006555371205703, 0.04333851102588555, 0.04333851102588555,
     0.04333851102588555, 0.04333851102588555, 0.04333851102588555,
     0.04333851102588555, 0.04333851102588555, 0.04333851102588555,
     0.0007251564770689873, 0.0007251564770689873, 7.377915317967555e-27,
     7.377915317967555e-27])

params2 = np.array(
    [0.3333333333333332, 0.3333333333333332, 0.3333333333333332,
     2.931077467598623e-93, 6.532951191692606e-25, 1.4080539652716124e-224,
     6.532951191692606e-25, 1.4080539652716124e-224, 1.4080539652716124e-224,
     6.532951191692606e-25, 6.532951191692606e-25, 6.532951191692606e-25,
     6.532951191692606e-25, 3.5127398105835854e-17, 3.5127398105835854e-17,
     3.5127398105835854e-17, 3.5127398105835854e-17, 3.5127398105835854e-17,
     3.5127398105835854e-17, 3.5127398105835854e-17, 3.5127398105835854e-17,
     7.860388790608641e-191, 7.860388790608641e-191, 2.931077467598623e-93,
     2.931077467598623e-93])

samples = np.array(
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

result1_scipy = multi_s.pmf(samples, samples.sum(), params1)
result2_scipy = multi_s.pmf(samples, samples.sum(), params2)
print(result1_scipy, result2_scipy)
print(params1.sum(), params2.sum())

print('-----------------')

result1_scipy_sum = multi_s.pmf(samples, samples.sum(), params1 / params1.sum())
result2_scipy_sum = multi_s.pmf(samples, samples.sum(), params2 / params2.sum())
print(result1_scipy_sum, result2_scipy_sum)
print((params1 / params1.sum()).sum(), (params2 / params2.sum()).sum())

print('-----------------')

result1 = safe_multi(samples, params1)
result2 = safe_multi(samples, params2)
print(result1, result2)

С выводом:

nan 0.22222222222222202
1.0000000000000002 0.9999999999999999
-----------------
0.058068684987554825 nan
0.9999999999999998 1.0000000000000002
-----------------
0.058068684987554825 0.22222222222222202

Есть ли лучший способ справиться с переполнением чисел, которое может произойти из параметров? Моя оболочка safe_multi (), кажется, делает свое дело, но меня интересуют лучшие практики борьбы с этим.

РЕДАКТИРОВАТЬ: я нашел пример, показанный ниже, который, кажется, всегда возвращает nan, несмотря на суммирование параметров 1. 1. 1011 *

c = np.array(
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

params = np.array(
    [0.02702702702702703, 0.02702702702702703, 0.0, 0.0, 0.0,
     0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.06756756756756757,
     0.0945945945945946, 0.06756756756756757, 0.06756756756756757,
     0.04054054054054054, 0.04054054054054054, 0.06756756756756757,
     0.06756756756756757, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.04054054054054054, 0.013513513513513514,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.02702702702702703, 0.02702702702702703, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514,
     0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

print(multi_s.pmf(c, c.sum(), params))
print(multi_s.pmf(c, c.sum(), params / params.sum()))
print(params.sum())

Вывод:

nan
nan
1.0

После анализа кода scipy кажется, что преступник находится на линии 3012 _multivariate.py:

p[..., -1] = 1. - p[..., :-1].sum(axis=-1)

Эта строка гарантирует, что параметры суммируются с 1, устанавливая последний параметр на соответствующее значение. В приведенном выше примере это добавляет очень маленькое отрицательное значение для последнего параметра, который затем помечается как проблема в будущем. Чтобы убедиться, что это условие выполнено, нельзя ли разделить на сумму параметров в отличие от выполнения вычитания?

...