Я сделал EDA для этого набора данных белого скули и я пытаюсь найти 3 предиктора качества и провести линейную регрессию по ним.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
wine = "~/Desktop/datasets/winequality-white.csv"
# Load the data
df = pd.read_csv(wine,sep=";")
df.head()
# Look at the information regarding its columns.
df.info()
# non-null floats also validated by √null_release_mask = df['fixed
acidity'].isnull()
Я пытаюсь разделить тест на поезда и выбрать 3 предиктора для прогнозирования качества
from sklearn.model_selection import train_test_split
X = df[["alcohol", "pH","free sulfur dioxide"]]
y = df["quality"]
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.3, random_state=42)
print(len(X_train), len(X_test))
print(len(y_train), len(y_test))`
from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(X_train,y_train)
import numpy as np
x_values_to_plot = np.linspace(0, df[["alcohol", "pH","free sulfur
dioxide"]].max(), 15)
y_values_to_plot = (x_values_to_plot * model.coef_) + model.intercept_
fig, ax = plt.subplots(figsize=(6,6))
ax.scatter(df[["alcohol", "pH","free sulfur dioxide"]], df["quality"],
label="data", alpha=0.2)
ax.plot(x_values_to_plot, y_values_to_plot, label="regression_line of
white wines", c="r")
ax.legend(loc="best")
plt.show()
Однако я получаю эту ошибку:
---------------------------------------------------------------------------
ValueError Traceback (most recent call
last)
<ipython-input-68-c52d735932ab> in <module>()
1 import numpy as np
2
----> 3 x_values_to_plot = np.linspace(0, df[["alcohol", "pH","free
sulfur dioxide"]].max(), 15)
4 y_values_to_plot = (x_values_to_plot * model.coef_) +
model.intercept_
5
~/anaconda3/lib/python3.7/site-packages/numpy/core/function_base.py in
linspace(start, stop, num, endpoint, retstep, dtype)
122 if num > 1:
123 step = delta / div
--> 124 if step == 0:
125 # Special handling for denormal numbers, gh-5437
126 y /= div
*ValueError: The truth value of an array with more than one element
is
ambiguous. Use a.any() or a.all()*
Любая помощь будет принята с благодарностью, я новичок в StackOverflow, поэтому помните о формате вопроса и дайте мне знать, что я могуулучшить.Спасибо