Scipy curve_fit путаница с использованием границ и начальных параметров на простых данных - PullRequest
0 голосов
/ 13 марта 2020

, хотя я отлично подошел для других наборов данных, по какой-то причине следующий код не работает для относительно простого набора точек. Я пробовал как экспоненциальную, так и убывающую экспоненты, а также начальные параметры и границы. Я считаю, что это разоблачает мое глубокое недопонимание; Я ценю любые советы.

    snr = [1e10, 5, 1, .5, .1, .05]
    tau = [1, 8, 10, 14, 35, 80]

    fig1, ax1 = plt.subplots()

    def fit(x, a, b, c): #c: asymptote
        #return a * np.exp(b * x) + 1.
        return np.power(x,a)*b + c

    xlist = np.arange(0,len(snr),1)
    p0 = [-1., 1., 1.]
    params = curve_fit(fit, xlist, tau, p0)#, bounds=([-np.inf, 0., 0.], [0., np.inf, np.inf]))

    a, b, c = params[0]
    print(a,b,c)
    ax1.plot(xlist, fit(xlist, a, b, c), c='b', label='Fit')

    #ax1.plot(snr, tau, zorder=-1, c='k', alpha=.25)
    ax1.scatter(snr, tau)
    ax1.set_xscale('log')        
    #ax1.set_xlim(.02, 15)
    plt.show()

Обновление 1: справочный рисунок, следующий за кодом Eri c M: Figure Будет комментировать в сообщение ниже.


Исправление для обновления 1: xlist = np.arange(0.01,10000,1)/1000+0.01

Ответы [ 2 ]

2 голосов
/ 13 марта 2020

Это сработало для меня. Была пара вопросов. Включая мой комментарий. В вашем xlist также есть ошибка «делить на ноль», поэтому я избежал этого, добавив 0,01 к xlist и увеличив плотность точек, чтобы кривая была округлена.

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

snr = [1e10, 5, 1, .5, .1, .05]
tau = [1, 8, 10, 14, 35, 80]

fig1, ax1 = plt.subplots()

def fit(x, a, b, c):
    return np.power(x, a)*b + c

xlist = np.arange(0.01,10000,1)/1000+0.01
xlist = np.append(xlist, 1e10)
p0 = [-10, 10., 1.]
params = curve_fit(fit, snr, tau, p0)

print('Fitting parameters: {}'.format(params[0]))
ax1.plot(xlist, fit(xlist, *params[0]), c='b', label='Fit')
ax1.scatter(snr, tau)
ax1.set_xscale('log')        
plt.show()

enter image description here

1 голос
/ 13 марта 2020
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit


def fit(x, a, b, c):
    return np.power(x, a)*b + c


x = [1e10, 5, 1, .5, .1, .05]
y = [1, 8, 10, 14, 35, 80]

popt, pcov=curve_fit(fit,x,y, bounds=([-np.inf, 0., 0.], [0., np.inf, np.inf]))
x_curve = np.append(np.linspace(0.01, 10, 1000), 1e11)

# plot
fig, ax = plt.subplots()
ax.set_ylim(-25,100)
ax.set_xscale("log")
ax.scatter(x, y)
plt.plot(x_curve, np.power(x_curve, popt[0])*popt[1] + popt[2], color = 'green')
plt.show()

Выход:

enter image description here

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...