DATASET: Реклама в социальных сетях
Я получаю сообщение об ошибке. Я просто пытаюсь внедрить K ближайшего соседа с нуля.
У меня есть набор социальных объявлений. Исходя из возраста и ожидаемой зарплаты, я буду классифицировать купленных или нет. Но я получаю сообщение об ошибке «KeyError: 227».
Вот основной код:
import numpy as np
import pandas as pd
from collections import Counter
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
y_pred = [self._predict(x) for x in X]
return np.array(y_pred)
def _predict(self, x):
# Compute distances between x and all examples in the training set
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
# Sort by distance and return indices of the first k neighbors
k_idx = np.argsort(distances)[:self.k]
# Extract the labels of the k nearest neighbor training samples
k_neighbor_labels = [self.y_train[i] for i in k_idx]
# return the most common class label
most_common = Counter(k_neighbor_labels).most_common(1)
return most_common[0][0]
def accuracy(y_true, y_pred):
accuracy = np.sum(y_true == y_pred) / len(y_true)
return accuracy
socialads =pd.read_csv("socialads.csv")
print(socialads)
X=socialads.iloc[:,2:3].values
y=socialads.iloc[:,4]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# Inspect data
print(X_train.shape)
print(X_train)
print(y_train.shape)
print(y_train)
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
cmap = ListedColormap(['#FF0000', '#00FF00'])
plt.figure()
plt.xlabel('X')
plt.ylabel('y')
plt.title('KNN Classification for Social Ads')
plt.scatter(X,y, c=y, cmap=cmap,edgecolor='k', s=30)
plt.show()
k = 3
clf = KNN(k=k)
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
accuracy=np.sum(predictions == y_test)/len(y_test)
print("custom KNN classification accuracy", accuracy(y_test, predictions))
Вот ошибка:
KeyError Traceback (most recent call last)
<ipython-input-645-06575c57434f> in <module>
2 clf = KNN(k=k)
3 clf.fit(X_train, y_train)
4 predictions = clf.predict(X_test)
5 accuracy=np.sum(predictions == y_test)/len(y_test)
6 print("custom KNN classification accuracy", accuracy(y_test, predictions))
<ipython-input-640-af87127ae54f> in predict(self, X)
9
10 def predict(self, X):
11 y_pred = [self._predict(x) for x in X]
12 return np.array(y_pred)
13
<ipython-input-640-af87127ae54f> in <listcomp>(.0)
9
10 def predict(self, X):
11 y_pred = [self._predict(x) for x in X]
12 return np.array(y_pred)
13
<ipython-input-640-af87127ae54f> in _predict(self, x)
18 k_idx = np.argsort(distances)[:self.k]
19 # Extract the labels of the k nearest neighbor training samples
20 k_neighbor_labels = [self.y_train[i] for i in k_idx]
21 # return the most common class label
22 most_common = Counter(k_neighbor_labels).most_common(1)
<ipython-input-640-af87127ae54f> in <listcomp>(.0)
18 k_idx = np.argsort(distances)[:self.k]
19 # Extract the labels of the k nearest neighbor training samples
20 k_neighbor_labels = [self.y_train[i] for i in k_idx]
21 # return the most common class label
22 most_common = Counter(k_neighbor_labels).most_common(1)
~\Anaconda3\lib\site-packages\pandas\core\series.py in __getitem__(self, key)
765 key = com._apply_if_callable(key, self)
766 try:
767 result = self.index.get_value(self, key)
768
769 if not is_scalar(result):
~\Anaconda3\lib\site-packages\pandas\core\indexes\base.py in get_value(self, series, key)
3116 try:
3117 return self._engine.get_value(s, k,
3118 tz=getattr(series.dtype, 'tz', None))
3119 except KeyError as e1:
3120 if len(self) > 0 and self.inferred_type in ['integer', 'boolean']:
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_value()
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_value()
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()
KeyError: 227
PLZ ИДЕНТИФИКАЦИЯ ПРОБЛЕМЫ.