Эффективно выберите точки данных, которые находятся близко к центру кластера - PullRequest
0 голосов
/ 05 декабря 2018

Предположим, у меня есть такой набор данных:

import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

X,y = make_blobs(random_state=101) # My data

palette = sns.color_palette('bright',3)
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) # Visualizing the data

enter image description here

Я хотел бы выбрать данные, которые находятся близко к центру кластера,Скажем, я хочу выбрать данные близко к центру из cluster '0', в настоящее время я делаю так:

label_0 = X[y==0] # Want to select data from the label '0'

data_index = 2 # Manaully pick the point
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

enter image description here

Так как этоне близко к центру, я меняю индекс и выбираю другой.

data_index = 4
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

Теперь это близко.Но мне интересно, есть ли более эффективный способ добиться этого?Это возможно для небольшого набора данных, такого как этот, но если мой набор данных имеет тысячи точек, я не думаю, что этот метод будет работать больше.enter image description here

1 Ответ

0 голосов
/ 05 декабря 2018

Один из подходов - использовать алгоритм K-средних .Это поможет вам найти центры каждого кластера.

С учетом вашего набора данных, шаги будут:

1) Найти количество кластеров

num_clusters=len(np.unique(y)) #here 3

2)Примените кластеризацию k-средних от scikit к вашим данным

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(X)

3) Найдите центр каждого кластера

centers=kmeans.cluster_centers_ # gives the centers of each cluster
# array([[ 0.26542862,  1.85466779],
#        [-9.50316411, -6.52747391],
#        [ 3.64354311,  6.62683956]])

4) Поскольку эти центры не могут быть частьюиз ваших исходных данных, нам нужно найти ближайшие точки к ним

from scipy import spatial

def nearest_point(array,query):
    return array[spatial.KDTree(array).query(query)[1]]

nearest_centers=np.array([nearest_point(X,center) for center in centers])
# array([[ 0.19313183,  1.80387958],
#       [-9.12488396, -6.32638926],
#       [ 3.65986315,  6.69035824]])

5) Построить исходные данные и центры

sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) 
for nc in nearest_centers:
    plt.scatter(nc[0],nc[1],marker='*',color='r')

Центры показаны красными крестиками:

The centers are shows by the red crosses

...