Я пытался построить модель линейной регрессии в python, используя scikit learn и matplotlib. Однако код запутался, когда я строил данные с помощью plt.scatter () и plt.plot ()
Вот мой код, который моделирует данные с использованием sklearn: -
from sklearn import linear_model
regr = linear_model.LinearRegression()
train_x = np.asanyarray(train[['ENGINESIZE']])
train_y = np.asanyarray(train[['CO2EMISSIONS']])
regr.fit (train_x, train_y)
# The coefficients
print ('Coefficients: ', regr.coef_)
print ('Intercept: ',regr.intercept_)
Вот мой код, который отображает модель линейной регрессии на графике: -
plt.scatter(train.ENGINESIZE, train.CO2EMISSIONS, color='blue')
plt.plot(train_x, regr.coef_[0][0]*train_x + regr.intercept_[0], '-y')
plt.xlabel("Engine size")
plt.ylabel("Emission")
Я не понимаю аргументы, переданные в plt.scatter()
и plt.plot()
. Я заметил, что при удалении метода plt.plot()
линия наилучшего соответствия не отображается на графике.