Почему функция SciPy curve_fit заботится о типе xdata? - PullRequest
0 голосов
/ 17 декабря 2018

Я пытался уместить некоторые данные, используя SciPy curve_fit, и получил действительно странные результаты.Поэтому я попробовал и попробовал и проверил и нашел проблему в типе xdata.Когда xdata имеет тип int, результаты становятся очень странными.Но это не относится ко всем функциям f.Я тестировал с полиномами до 6-го порядка. Начиная с 3-го порядка результаты стали странными.

Минимальный пример:

import numpy as np
from scipy.optimize import curve_fit

def poly4(x, a, b, c, d, e):
    return a*np.power(x,4) + b*np.power(x,3) + c*np.power(x,2) + d*x + e

x = np.linspace(0, 9.6, 2400)
y = poly4(x, 0.03, -0.68, 5.6, -22, 1351)

x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
x2 = np.arange(0, 2400, 1, dtype=np.dtype('int'))

popt1,_ = curve_fit(poly4, x1, y)
popt2,_ = curve_fit(poly4, x2, y)

f1 = poly4(x1, *popt1)
f2 = poly4(x2, *popt2)

Построение этих значений с помощью

import matplotlib.pyplot as plt
plt.plot(f1, label='f1, float range')
plt.plot(f2, label='f2, int range')
plt.legend()
plt.show()

дает

curve_fit plot with int and float range

Синяя линия - именно то, как должен выглядеть результат.Глядя на вывод curve_fit с

print(popt1)
print(popt2)

дает

[9.05733149e-12 -4.92513534e-08 9.73032914e-05 -9.17048770e-02 1.35100000e + 03]

[3.52993170e-11 -1.52725549e-10 9.38577666e-06 -3.58806105e-02 1.34272489e + 03]

Почему эти результаты такие разные?Ну, очевидно, из-за типа данных xdata.Но почему curve_fit должен заботиться о типе данных xdata?Я не вижу причины этого и не нашел никакой документации по этому поводу.

Редактировать: Проверено на python 3.6.3 с scipy 0.19.1 и python 3.7.1 с scipy 1.1.0.Оба в Windows.

Ответы [ 2 ]

0 голосов
/ 17 декабря 2018

Тип x заботит не curve_fit, а ваша функция poly4.Numpy сохраняет тип массивов в своих операциях.Поскольку вы берете n-степень целого числа, вы быстро столкнетесь с целочисленным переполнением, что приведет к неожиданным результатам.

См., Например, вывод np.power (x, 3):

x = np.arange(0,2400,1, dtype=np.int32)
plt.plot(x,np.power(x,3))

enter image description here

0 голосов
/ 17 декабря 2018

Проблема, с которой вы и все, кто не может воспроизвести вашу проблему, заключается в том, что размер np.dtype('int') различен на разных платформах.Если вы замените свои декларации x1 и x2 на:

x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
x2 = np.arange(0, 2400, 1, dtype=np.int32)

, то вы сможете последовательно воспроизводить странный вывод независимо от платформы:

enter image description here

Первоначальная проблема вызвана тем фактом, что np.int32 слишком мал, чтобы иметь дело с некоторыми из очень больших чисел, которые вы вычисляете, и значения промежуточного вычисления переполняются.Таким образом, результат:

poly4(np.arange(2000, 2010, dtype=np.int32), 0.03, -0.68, 5.6, -22, 1351)
# array([4.60917546e+08, 3.82703937e+08, 4.34772636e+08, 3.59427040e+08,
   4.14366625e+08, 3.41894792e+08, 3.99711018e+08, 3.30118704e+08,
   3.90817330e+08, 3.24110298e+08])

сильно отличается от результата:

poly4(np.arange(2000, 2010, dtype=np.int64), 0.03, -0.68, 5.6, -22, 1351)
# array([4.74582357e+11, 4.75534936e+11, 4.76488948e+11, 4.77444394e+11,
   4.78401277e+11, 4.79359597e+11, 4.80319357e+11, 4.81280557e+11,
   4.82243198e+11, 4.83207283e+11])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...