Как отобразить мою линию полиномиальной регрессии? - PullRequest
1 голос
/ 03 октября 2019

У моего сюжета очень жирная линия, которую я не ожидал и не смог устранить самостоятельно. Я не знаю, как показать изображение.

Выполнение EDA для набора данных Craigslist Kaggle Auto. Я хочу отобразить, а затем сравнить и сопоставить соответствие линейной и полиномиальной регрессии, соотнося цену и год выпуска модели для каждой уникальной марки автомобиля и модели (например, Ford F150).

Как сделать следующий график с более нормальнымсмотря линия, ширина линии ничего не меняет.

enter image description here

df_f150=df[df['Make and Model']=='ford F-150']

#plotting a linear regression line for each dataframe
fig = plt.figure(figsize=(10,7))
sns.regplot(x=df_f150.year, y=df_f150.price, color='b')


'#Here is where I try to do one of the polynomial regressions'

# Legend, title and labels.
#plt.legend(labels=x)
plt.title('Relationship Between Model Year and Price', size=24)
plt.xlabel('Year', size=18)
plt.ylabel('Price', size=18)
plt.xlim(1990,2020)
plt.ylim(1000,100000)

from sklearn.preprocessing import PolynomialFeatures 


X = df_f150['year'].values.reshape(-1,1)
y = df_f150['price'].values.reshape(-1,1)

poly = PolynomialFeatures(degree = 8) 
poly.fit_transform(X) 

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

regressor = LinearRegression()  
regressor.fit(X_train, y_train) #training the algorithm

#To retrieve the intercept:
print(regressor.intercept_)
#For retrieving the slope:
print(regressor.coef_)

y_pred = regressor.predict(X_test)

dfres = pd.DataFrame({'Actual': y_test.flatten(), 'Predicted': y_pred.flatten()})
dfres

plt.scatter(X_test, y_test,  color='gray')
plt.plot(X_test, y_pred, color='red', linewidth=2)
plt.show()

1 Ответ

1 голос
/ 05 октября 2019

Во-первых, всегда чистите и проверяйте данные:

  • Данные даны от Kaggle: списки автомобилей от Craigslist.org
  • Кстати, сюжет генерируетсяsns.regplot почти совпадает с генерируемым при выполнении регрессии с sklearn. Поэтому я не включил дополнительный код.

Загрузка и выбор данных:

from pathlib import Path
import pandas as pd


file = Path.cwd() / 'data/craigslist-carstrucks-data/craigslistVehicles.csv'

df = pd.read_csv(file, usecols=['price', 'year', 'manufacturer', 'make'])

 price    year manufacturer          make
  3500  2006.0    chevrolet           NaN
  3399  2002.0        lexus         es300
  9000  2009.0    chevrolet  suburban lt2
 31999  2012.0          ram          2500
 16990  2003.0          ram          3500

# Select specific data:
# outliers exist, so price < 120000 and f-150 began production in 1975
ford = df[['price', 'year']][(df.manufacturer == 'ford') & (df.make == 'f-150') & (df.price < 120000) & (df.year >= 1975)]

 price    year
  1600  1992.0
 39760  2018.0
 11490  2014.0
  2500  1993.0
 17950  2014.0

Участок с seaborn :

import seaborn as sns

sns.regplot(x=ford.year, y=ford.price)
plt.show()

enter image description here

Вот график, без удаления выбросов:

  • Участок плоский, потому что максимальная цена составляет 8.888889e+07
    • Вы установили plt.ylim(1000,100000), поэтому выбросы не отображаются
  • Я принял произвольное решение исключить все цены выше $ 120 тыс., Потому что я знаю, что это нереальная цена для этого образца.
  • Простое удаление выбросов не всегда лучший вариант.

enter image description here

print(ford.describe())

              price          year
count  1.127000e+04  11270.000000
mean   2.405777e+04   2010.459184
std    8.372461e+05      6.454361
min    0.000000e+00   1975.000000
25%    5.300000e+03   2007.000000
50%    1.548750e+04   2012.000000
75%    2.549500e+04   2015.000000
max    8.888889e+07   2020.000000

График выполнения регрессии с помощью sklearn

import matplotlib.pyplot as plt

plt.scatter(X_test, y_test)
plt.plot(X_test, y_pred, color='violet', linewidth=3)
plt.show()

enter image description here

График X_test & y_pred в sns.regplot():

sns.regplot(x=ford.year, y=ford.price)
sns.scatterplot(X_test.flatten(), y_pred.flatten(), color='r')
plt.show()

enter image description here

...