Дерево решений precision_score дает "ValueError: Найдены входные переменные с непоследовательным количеством выборок" - PullRequest
0 голосов
/ 13 октября 2018

Я пытаюсь создать дерево решений с заданными данными.Но по какой-то причине accuracy_score дает

ValueError: Найдены входные переменные с непоследовательным количеством выборок:

, когда я разделяю свои тренировочные данные на проверку (% 20) иобучение (% 80).

Вот как я разбил свои данные:

from sklearn.utils import shuffle

from sklearn.model_selection import train_test_split

# stDt shuffled training set

stDt = shuffle(tDt) 

#divide shuffled training set to training and validation set

stDt, vtDt = train_test_split(stDt,train_size=0.8, shuffle=False)

print(tDt.shape)
print(stDt.shape)
print(vtDt.shape)

Вот как я тренирую данные:

#attibutes and labels of training set

attributesT =  stDt.values

labelsT = stDt.label


# Train Decision tree classifiers
from sklearn.tree import DecisionTreeClassifier


dtree1 = DecisionTreeClassifier(min_samples_split = 1.0)

dtree2 = DecisionTreeClassifier(min_samples_split = 3)

dtree3 = DecisionTreeClassifier(min_samples_split = 5)



fited1 = dtree1.fit(attributesT,labelsT)

fited2 = dtree2.fit(attributesT,labelsT)

fited3 = dtree3.fit(attributesT,labelsT)

Вот тест и точностьПартитура:

from sklearn.metrics import accuracy_score

ret1 = fited1.predict(stDt)

ret2 = fited2.predict(stDt)

ret3 = fited3.predict(stDt)

print(accuracy_score(vtDt.label,ret1))

1 Ответ

0 голосов
/ 14 октября 2018

Полученная ошибка является ожидаемой, поскольку вы пытаетесь сравнить прогнозы, полученные из вашего обучающего набора (ret1 = fited1.predict(stDt)), с метками вашего проверки набора (vtDt.label).

Вот правильный способ повысить точность обучения и проверки для вашей модели fitted1 (аналогично для остальных):

# predictions on the training set:
ret1 = fitted1.predict(stDt)

# training accuracy:
accuracy_score(stDt.label,ret1)

# predictions on the validation set:
pred1 = fitted1.predict(vtDt)

# validation accuracy:
accuracy_score(vtDt.label,pred1)
...