Машинное обучение и линейная регрессия. Ожидаемый двумерный массив, данные формы - PullRequest
0 голосов
/ 03 июля 2018

Я очень новичок в машинном обучении, и поэтому, используя набор данных Pokemon, я решил написать тестовую программу для прогнозирования "коэффициента вылова" на основе "общих" данных. Я хотел использовать линейную регрессию для моих тренировочных данных. Но когда я запускаю свою программу, я получаю следующую ошибку:

Ожидаемый двумерный массив, вместо него получен одномерный массив: array = ['190' '90' '45' '125' «190», «75», «45», «120», «200», «45», «190», «60», «225», «90», «3», «45», «150». «120», «45», «3», «3», «255», «90», «45», «45», «45», «255», «225», «190», «190». «255», «90», «45», «45», «30», «45», «45», «90», «190», «90», «45», «90», «60». «45», «60», «75», «55», «75», «45», «45», «3», «255», «45», «3», «45», «90», «190». '60' '190' '200' '225' '75' '45' '45' '45' '200' '120' '120' '255' «60», «45», «45», «75», «60», «60», «190», «75», «45», «120», «190», «200», «235». «45», «45», «90», «30», «45», «45», «170», «235», «45», «190», «60», «75», «180». '45' '235' '190' '45' '120' '45' '75' '190' '45' '45' '45' '45' '45' «75», «45», «45», «190», «45», «75», «3», «45», «60», «200», «45», «45», «255». «255», «120», «45», «255», «125», «120», «60», «45», «45», «60», «255», «45». «180», «60», «45», «60», «3», «25», «120», «45», «3», «3», «45», «75», «30», «45». «255», «30», «75», «255», «255», «180», «255», «45», «45», «120», «255», «75». «30», «45», «75», «45», «255», «120», «45», «45», «45», «190», «45», «75», «45». '45' '3' '60' '30' '60' '200' '45' '75' '120' '25' '255' '45' '255' «200», «190», «190», «120», «45», «90», «170», «45», «75», «60», «100», «45». «45», «90», «45», «45», «45», «255», «60», «90», «140», «45», «90», «75», «200». «45», «45», «255», «120», «3», «45», «75», «200», «255», «225», «120», «120». «200», «45», «45», «50», «190», «45», «45», «45», «45», «45», «45», «30», «3», «3». '255' '45' '45' '255' '120' '225' '45' '75' '75' '45' '60' '255' '60' '60' '45' '120' '255' '45' '225' '255' '45' '45' '3' '255' '190' '30' «190», «45», «45», «120», «75», «25», «75», «255», «45», «120», «100», «3», «65». '45' '75' '180' '45' '45' '3' '255' '45' '45' '90' '225' '190' '45' '255' '3' '190' '70' '3' '120' '45' '45' '50' '200' '190' '255' '55' «150», «45», «3», «25», «60», «45», «120», «45», «205», «60», «45», «45», «255». «30», «120», «75», «45», «90», «45», «45», «60», «190», «45», «45», «90», «45». «3», «75», «90», «200», «180», «45», «45», «75», «90», «45», «3», «120», «45». «45», «45», «45», «75», «45», «155», «45», «55», «45», «30», «45», «150», «255». 45 '75' 180 '' 15 '' 190 '' 255 '' 75 '' 190 '' 45 '' 190 '' 90 '' 255 ' «45», «45», «45», «190», «3», «60», «45», «60», «60», «255», «25», «145», «45». «45», «120», «50», «45», «120», «45», «255», «45», «45», «45», «50», «225», «30». '75' '120' '3' '45' '120' '30' '45' '255' '90' '3' '3' '120' '45' «127», «120», «200», «255», «25», «45», «75», «120», «255», «190», «220», «45». '65' '45' '90' '60' '200' '190' '190' '120' '190' '90' '45' '120' '75' '190' '75' '90' '120' '90' '75' '45' '190' '45' '100' '60' '3' '45' '90' '190' '255' '45' '190' '45' '45' '25' '60' '60' '45' '190' «45» 190 «30» 190 «45» 190 «255» 45 «45» 3 «120» 3 «45» '35' '120' '190' '255' '190' '45' '45' '45' '45' '255' '190' '45' «190», «225», «45», «190», «255», «45», «190», «45», «255», «75», «45», «90». «120», «30», «180», «190», «100», «255», «235», «75», «60», «190», «160», «45». «3», «120», «45», «3», «120», «45», «45», «45», «127», «75», «190», «140», «75». «225», «60», «45», «75», «120», «190», «190», «90», «3», «45», «150», «120», «30». «50», «45», «60», «190», «255», «125», «120», «75», «60», «90», «140»].

Измените ваши данные, используя array.reshape (-1, 1), если ваши данные имеют один feature или array.reshape (1, -1), если он содержит один образец.

Чтобы исправить свою ошибку, я попытался изменить свой список x_train, так как кажется, что он упомянут выше, но я все еще получаю ту же ошибку. Возможно, мой синтаксис выключен? Я попробовал x_train.reshape(-1, 1) и x_train = x_train.reshape(-1, 1) из другого предложения, которое нашел, но безуспешно.

Вот (грубый) код, который я написал до сих пор:

from sklearn import cross_validation
from sklearn import svm
from sklearn.feature_selection import RFE
from sklearn.model_selection import train_test_split
from sklearn import linear_model
import numpy as np
import matplotlib as plt
import csv

# Create linear regression object
regr = linear_model.LinearRegression()

# Create lists and append data -- we want to predict the catch rate!
total = []
catch_rate = []

with open("pokemon.csv") as f:
    reader = csv.reader(f)
    next(reader) # skip header
    for row in reader:
        total.append(row[5])
        catch_rate.append(row[21])

x_train, x_test, y_train, y_test = 
cross_validation.train_test_split(catch_rate, total, test_size=0.25, 
random_state=0)


# Train the model using the training sets
regr.fit(x_train, y_train)

# Make predictions using the testing set
pokemon_y_pred = regr.predict(x_test)

# Plot outputs
plt.scatter(x_test, y_test,  color='black')
plt.plot(x_test, pokemon_y_pred, color='blue', linewidth=3)

plt.xticks(())
plt.yticks(())

plt.show()

Может быть, я упустил что-то еще в моем понимании кода? Опять же, я учу себя, поэтому я очень признателен за любую помощь.

1 Ответ

0 голосов
/ 03 июля 2018

Используйте pandas фрейм данных вместо list типа. Также обратите внимание, что первый элемент функции train_test_split должен быть кадром данных как минимум с двумя столбцами.


Итак, если ваш CSV-файл выглядит следующим образом:

Id,Name,Type_1,Type_2,Total,HP,Attack,Defense,Sp_Atk,Sp_Def,Speed,Generation,isLegendary,Color,hasGender,Pr_Male,Egg_Group_1,Egg_Group_2,hasMegaEvolution,Height_m,Weight_kg,Catch_Rate
1,Bulbasaur,Grass,Poison,318,45,49,49,65,65,45,1,False,Green,True,0.875,Monster,Grass,False,0.71,6.9,45
2,Ivysaur,Grass,Poison,405,60,62,63,80,80,60,1,False,Green,True,0.875,Monster,Grass,False,0.99,13,45
3,Venusaur,Grass,Poison,525,80,82,83,100,100,80,1,False,Green,True,0.875,Monster,Grass,True,2.01,100,45
4,Charmander,Fire,,309,39,52,43,60,50,65,1,False,Red,True,0.875,Monster,Dragon,False,0.61,8.5,45
5,Charmeleon,Fire,,405,58,64,58,80,65,80,1,False,Red,True,0.875,Monster,Dragon,False,1.09,19,45

И используя следующий код:

from sklearn import cross_validation
from sklearn import svm
from sklearn.feature_selection import RFE
from sklearn.model_selection import train_test_split
from sklearn import linear_model
import numpy as np
import matplotlib as plt
import pandas as pd #import pandas

# Create linear regression object
regr = linear_model.LinearRegression()

#load csv file with pandas
df = pd.read_csv("pokemon.csv")
#remove all string columns
df = df.drop(['Name', 'Type_1','Type_2','Color','Egg_Group_1','Egg_Group_2'], axis=1)

y= df.Catch_Rate

x_train, x_test, y_train, y_test = cross_validation.train_test_split(df, y, test_size=0.25, random_state=0)


# Train the model using the training sets
regr.fit(x_train, y_train)

# Make predictions using the testing set
pokemon_y_pred = regr.predict(x_test)

print pokemon_y_pred


# [ code continuation ...]

Вы получите:

[ 45.  45.]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...