Не могу заставить программу работать. Я хочу, чтобы программа делала прогноз, для какого сорта пшеницы она соответствует в соответствии с вводом пользователя. Однако я снова получаю результат индекса за пределами допустимого диапазона. Я подумал, может быть, это как-то связано с возвратом.
Вот код
import pandas as pd
import numpy as np
import os
import csv
import sys
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
with open('example.csv', 'r') as csvfile:
readCSV = csv.reader(csvfile, delimiter=',')
next(readCSV)
Wheat_Varietys = []
id = 0
for line in readCSV:
area = line[0]
perimiter = line[1]
compactness = line[2]
kernel_length = line[3]
kernel_width = line[4]
assymetry_coeff = line[5]
groove_length = line[6]
Wheat_Variety = line[7]
Wheat_Varietys.append({ 'id': id,
'attributes':
{
'area': area,
'perimiter': perimiter,
'compactness': compactness,
'kernel_length': kernel_length,
'kernel_width': kernel_width,
'assymetry_coeff': assymetry_coeff,
'groove_length': groove_length,
'Wheat_Variety': Wheat_Variety,
}})
id+=1
whatArea = input('What area do you wish to enter?')
whatPeri = input('What perimiter do you wish to enter?')
whatComp = input('What compactness do you wish to enter?')
whatKerL = input('What kernel length do you wish to enter?')
whatKerW = input('What kernel width do you wish to enter?')
whatAssym = input('What Assymetry Coefficient do you wish to enter?')
whatGroove = input('What Groove Lenght do you wish to enter?')
wheat_dataset = pd.read_csv('example.csv')
wheat_data_cols = wheat_dataset[['area', 'perimeter', 'compactness',
'kernel_length', 'kernel_width', 'assymetry_coeff' , 'groove_length']]
x_train, x_test, y_train, y_test = train_test_split(wheat_data_cols,
wheat_dataset['Wheat_Variety'], train_size = 0.8, test_size = 0.2,
random_state = 0)
nb = GaussianNB()
nb.fit(x_train, y_train)
def predict_wheat_classification(whatArea, whatPeri, whatComp, whatKerL,
whatKerW, whatAssym, whatGroove):
x_new = np.array([[whatArea, whatPeri, whatComp, whatKerL, whatKerW,
whatAssym, whatGroove]])
prediction = nb.predict(x_new)
return prediction[0]
print(predict_wheat_classification(2, 2, 2, 2, 2, 2, 2))
, а вот CSV
example.csv