Как я могу кластеризовать граф, созданный в NetworkX? - PullRequest
1 голос
/ 14 июля 2020

Я пытаюсь применить кластеризацию к набору данных. Перед этим мне нужно разделить график на n кластеров, и я не знаю, как это сделать.

1 Ответ

1 голос
/ 15 июля 2020

Предположим, что список краев ваших невзвешенных и ненаправленных графа был сохранен в файле edge.txt. Вы можете выполнить следующие шаги для кластеризации узлов графа.

Шаг 1: получить встраивание каждого узла в граф. Это означает, что вам нужно получить непрерывное векторное представление для каждого узла. Вы можете использовать такие методы встраивания графов, как node2ve c, deepwalk , et c, чтобы получить встраивание. Обратите внимание, что такие методы сохраняют структурное сходство между узлами графа в векторном представлении (пространстве вложения). В следующем примере показано, как это сделать.

import networkx as nx
G=nx.Graph();
G=nx.read_edgelist("edges.txt") # edges.txt contains the edge list of your graph

# help to draw https://networkx.github.io/documentation/networkx-1.9/examples/drawing/labels_and_colors.html
nx.draw(G,with_labels = True,node_color='b',node_size=500);

from node2vec import Node2Vec
# Generate walks
node2vec = Node2Vec(G, dimensions=2, walk_length=20, num_walks=10,workers=4)
# Learn embeddings 
model = node2vec.fit(window=10, min_count=1)
#model.wv.most_similar('1')
model.wv.save_word2vec_format("embedding.emb") #save the embedding in file embedding.emb

Шаг 2: примените метод кластеризации. Получив векторное представление узлов, вы можете кластеризовать узлы на основе этих представлений. См. Пример ниже.

from sklearn.cluster import KMeans
import numpy as np


X = np.loadtxt("embedding.emb", skiprows=1) # load the embedding of the nodes of the graph
#print(X)
# sort the embedding based on node index in the first column in X
X=X[X[:,0].argsort()]; 
#print(X)
Z=X[0:X.shape[0],1:X.shape[1]]; # remove the node index from X and save in Z

kmeans = KMeans(n_clusters=2, random_state=0).fit(Z) # apply kmeans on Z
labels=kmeans.labels_  # get the cluster labels of the nodes.
print(labels)
...