Как использовать Curvefit в Python - PullRequest
0 голосов
/ 09 марта 2019

Я изучаю нелинейную кривую с питоном.
Я сделал пример, как показано ниже.
Но оптимизированный сюжет прорисован не хорошо

plt.plot(basketCont, fittedData)

Полагаю, оптимизированные параметры тоже не годятся.
Не могли бы вы дать некоторые рекомендации? Спасибо.

import matplotlib
matplotlib.use('Qt4Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from scipy.optimize import curve_fit 

def func(x, a, b, c):
    return a - b* np.exp(c * x) 

baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])

popt,pcov = curve_fit(func, baskets, scaling_factor)

print (popt)
print (pcov)

basketCont=np.linspace(min(baskets),max(baskets),50)
fittedData=[func(x, *popt) for x in basketCont]

fig1 = plt.figure(1)

plt.scatter(baskets, scaling_factor, s=5)
plt.plot(basketCont, fittedData)

plt.grid()

plt.show()

1 Ответ

2 голосов
/ 09 марта 2019

Лично я не смог точно подогнать ваши данные, используя выложенное вами уравнение, однако сигмоидальное уравнение Хилла дало хорошее соответствие.Вот код Python для графического установщика, который я использовал.

plot

import numpy, scipy, matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import warnings


baskets = numpy.array([475.0, 108.0, 2.0, 38.0, 320.0])
scaling_factor = numpy.array([95.5, 57.7, 1.4, 21.9, 88.8])

# rename data for simpler code re-use later
xData = baskets
yData = scaling_factor


def func(x, a, b, c): # Hill sigmoidal equation from zunzun.com
    return  a * numpy.power(x, b) / (numpy.power(c, b) + numpy.power(x, b)) 


# these are the same as the scipy defaults
initialParameters = numpy.array([1.0, 1.0, 1.0])

# do not print unnecessary warnings during curve_fit()
warnings.filterwarnings("ignore")

# curve fit the test data
fittedParameters, pcov = curve_fit(func, xData, yData, initialParameters)

modelPredictions = func(xData, *fittedParameters) 

absError = modelPredictions - yData

SE = numpy.square(absError) # squared errors
MSE = numpy.mean(SE) # mean squared errors
RMSE = numpy.sqrt(MSE) # Root Mean Squared Error, RMSE
Rsquared = 1.0 - (numpy.var(absError) / numpy.var(yData))

print('Parameters:', fittedParameters)
print('RMSE:', RMSE)
print('R-squared:', Rsquared)

print()


##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth, graphHeight):
    f = plt.figure(figsize=(graphWidth/100.0, graphHeight/100.0), dpi=100)
    axes = f.add_subplot(111)

    # first the raw data as a scatter plot
    axes.plot(xData, yData,  'D')

    # create data for the fitted equation plot
    xModel = numpy.linspace(min(xData), max(xData))
    yModel = func(xModel, *fittedParameters)

    # now the model as a line plot
    axes.plot(xModel, yModel)

    axes.set_xlabel('X Data') # X axis data label
    axes.set_ylabel('Y Data') # Y axis data label

    plt.show()
    plt.close('all') # clean up after using pyplot

graphWidth = 800
graphHeight = 600
ModelAndScatterPlot(graphWidth, graphHeight)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...