Подгонка линейной модели Scikit-learn после запуска графика matplotlib возвращает ошибку значения - PullRequest
0 голосов
/ 06 мая 2020

Я запускаю код из первой главы Практического машинного обучения Орелиена Жерона с помощью Scikit-Learn и TensorFlow.

Код, который я пытаюсь запустить:

# Code example
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# Load the data
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")

# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()

# Select a linear model
model = sklearn.linear_model.LinearRegression()

# Train the model
model.fit(X, y)

Он не работает на шаге model.fit(X, y) со следующей трассировкой:

ValueError                                Traceback (most recent call last)
 in 
     23 
     24 # # Train the model
---> 25 model.fit(X, y)
     26 
     27 # # Make a prediction for Cyprus

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\sklearn\linear_model\_base.py in fit(self, X, y, sample_weight)
    531         else:
    532             self.coef_, self._residues, self.rank_, self.singular_ = \
--> 533                 linalg.lstsq(X, y)
    534             self.coef_ = self.coef_.T
    535 

~\AppData\Local\Programs\Python\venv\ds\lib\site-packages\scipy\linalg\basic.py in lstsq(a, b, cond, overwrite_a, overwrite_b, check_finite, lapack_driver)
   1223             raise LinAlgError("SVD did not converge in Linear Least Squares")
   1224         if info < 0:
-> 1225             raise ValueError('illegal value in %d-th argument of internal %s'
   1226                              % (-info, lapack_driver))
   1227         resids = np.asarray([], dtype=x.dtype)

ValueError: illegal value in 4-th argument of internal None

Однако, когда я повторно запускаю функцию подгонки без команды plt.show(), она работает нормально:

country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')

model.fit(X, y) # works OK

# # Make a prediction for Cyprus
X_new = [[22587]]  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]

Очень странное поведение. Не уверен, что это связано с моими версиями пакетов. Вот мои текущие версии пакета:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.2.1
numpy==1.18.4
pandas==0.25.3
scikit-image==0.16.2
scikit-learn==0.22
scipy==1.4.1

1 Ответ

0 голосов
/ 06 мая 2020

Я запускал код 10 раз, и он успешно завершился. Похоже, вы что-то упустили в своем коде. Полный код, 10 попыток фрагмента кода, который ломается, результаты распечатываются.

# Common imports
import numpy as np
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

def prepare_country_stats(oecd_bli, gdp_per_capita):
    oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
    oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
    gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
    gdp_per_capita.set_index("Country", inplace=True)
    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                                  left_index=True, right_index=True)
    full_country_stats.sort_values(by="GDP per capita", inplace=True)
    remove_indices = [0, 1, 6, 8, 33, 34, 35]
    keep_indices = list(set(range(36)) - set(remove_indices))
    return full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]


# Load the data
oecd_bli = pd.read_csv("oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv("gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")


oecd_bli.head(3)
#  LOCATION    Country INDICATOR  ... Value Flag Codes            Flags
#0      AUS  Australia   HO_BASE  ...   1.1          E  Estimated value
#1      AUT    Austria   HO_BASE  ...   1.0        NaN              NaN
#2      BEL    Belgium   HO_BASE  ...   2.0        NaN              NaN


gdp_per_capita.head(3)
#                                            Subject Descriptor  ... #Estimates Start After
#Country                                                         ...
#Afghanistan  Gross domestic product per capita, current prices  ...                #2013.0
#Albania      Gross domestic product per capita, current prices  ...                #2010.0
#Algeria      Gross domestic product per capita, current prices  ...                #2014.0


# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

X[0:3]
#array([[ 9054.914],
#       [ 9437.372],
#       [12239.894]])

y[0:3]
#array([[6. ],
#       [5.6],
#       [4.9]])

results = list()
for i in range(10):
    # Visualize the data
    country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
    plt.show()

    # Select a linear model
    model = sklearn.linear_model.LinearRegression()

    # Train the model
    model.fit(X, y)

    # Make a prediction for Cyprus
    X_new = [[22587]]  # Cyprus' GDP per capita
    results.append(model.predict(X_new))


print(results)
#[array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]]),
# array([[5.96242338]])]

И:

pip freeze | grep -E "numpy|pandas|scipy|matplotlib|sci"
matplotlib==3.1.2
numpy==1.17.4
pandas==0.25.3
pandas-flavor==0.2.0
scikit-learn==0.22.1
scikit-plot==0.3.7
scipy==1.4.1
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...