I needed to perform the clustering technique described by the OP millions of times for rather large lists, so I wanted to determine which of the methods suggested above is both the most accurate and the most performant. I ran ten trials for input lists ranging in size from 2^1 to 2^10 for each method above, using the same input list for every method at each size, and measured the average runtime of each algorithm in milliseconds. Here are the results:
These results helped me see that, of the methods that consistently return correct results, @jochen's is the fastest. Among the methods that do not consistently return correct results, mak's solution often fails to include all of the input elements (i.e., some list members go missing), and the solutions from braaksma, cmangla, and asterisk are not guaranteed to be maximally merged.
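To make those two failure modes concrete, here is a tiny illustrative input (my own, not one of the benchmark inputs) and the kinds of output the validation code below distinguishes:

# sublists that share an element should end up in the same output cluster
L = [[1, 2], [2, 3], [4, 5]]

# maximally merged, all items present:              [[1, 2, 3], [4, 5]]
# not maximally merged (element 2 in two clusters): [[1, 2], [2, 3], [4, 5]]
# missing items (4 and 5 dropped):                  [[1, 2, 3]]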
Interestingly, the two fastest correct algorithms have, as of this writing, the two highest vote counts, in the proper rank order.
Here is the code used to run the tests:
from networkx.algorithms.components.connected import connected_components
from itertools import chain
from random import randint, random
from collections import defaultdict, deque
from copy import deepcopy
from multiprocessing import Pool
from functools import reduce
import networkx
import datetime
import os
##
# @mimomu
##
def mimomu(l):
  l = deepcopy(l)
  s = set(chain.from_iterable(l))
  for i in s:
    components = [x for x in l if i in x]
    for j in components:
      l.remove(j)
    l += [list(set(chain.from_iterable(components)))]
  return l
##
# @Howard
##
def howard(l):
  out = []
  while len(l) > 0:
    first, *rest = l
    first = set(first)
    lf = -1
    while len(first) > lf:
      lf = len(first)
      rest2 = []
      for r in rest:
        if len(first.intersection(set(r))) > 0:
          first |= set(r)
        else:
          rest2.append(r)
      rest = rest2
    out.append(first)
    l = rest
  return out
##
# Nx @Jochen Ritzel
##
def jochen(l):
  l = deepcopy(l)
  def to_graph(l):
    G = networkx.Graph()
    for part in l:
      # each sublist is a bunch of nodes
      G.add_nodes_from(part)
      # it also implies a number of edges:
      G.add_edges_from(to_edges(part))
    return G
  def to_edges(l):
    """
    treat `l` as a Graph and return its edges
    to_edges(['a','b','c','d']) -> [(a,b), (b,c), (c,d)]
    """
    it = iter(l)
    last = next(it)
    for current in it:
      yield last, current
      last = current
  G = to_graph(l)
  return list(connected_components(G))
##
# Merge all @MAK
##
def mak(l):
  l = deepcopy(l)
  taken = [False] * len(l)
  # note: in Python 3, map() returns a one-shot iterator; the nested
  # enumerate() loops below exhaust it, which is why this version often
  # drops input items (kept as-is to match the benchmarked behavior)
  l = map(set, l)
  def dfs(node, index):
    taken[index] = True
    ret = node
    for i, item in enumerate(l):
      if not taken[i] and not ret.isdisjoint(item):
        ret.update(dfs(item, i))
    return ret
  def merge_all():
    ret = []
    for i, node in enumerate(l):
      if not taken[i]:
        ret.append(list(dfs(node, i)))
    return ret
  result = list(merge_all())
  return result
##
# @cmangla
##
def cmangla(l):
  l = deepcopy(l)
  len_l = len(l)
  i = 0
  while i < (len_l - 1):
    for j in range(i + 1, len_l):
      # i,j iterate over all pairs of l's elements including new
      # elements from merged pairs. We use len_l because len(l)
      # may change as we iterate
      i_set = set(l[i])
      j_set = set(l[j])
      if len(i_set.intersection(j_set)) > 0:
        # remove these two from the list
        l.pop(j)
        l.pop(i)
        # merge them and append to the orig. list
        ij_union = list(i_set.union(j_set))
        l.append(ij_union)
        # len(l) has changed
        len_l -= 1
        # adjust 'i' because elements shifted
        i -= 1
        # abort inner loop, continue with next l[i]
        break
    i += 1
  return l
##
# @pillmuncher
##
def pillmuncher(l):
  l = deepcopy(l)
  def connected_components(lists):
    neighbors = defaultdict(set)
    seen = set()
    for each in lists:
      for item in each:
        neighbors[item].update(each)
    def component(node, neighbors=neighbors, seen=seen, see=seen.add):
      nodes = set([node])
      next_node = nodes.pop
      while nodes:
        node = next_node()
        see(node)
        nodes |= neighbors[node] - seen
        yield node
    for node in neighbors:
      if node not in seen:
        yield sorted(component(node))
  return list(connected_components(l))
##
# @NicholasBraaksma
##
def braaksma(l):
  l = deepcopy(l)
  lists = sorted([sorted(x) for x in l])  # Sort the lists so you don't miss things. Trust me, it needs to be done.
  resultlist = []  # Create the empty result list.
  if len(lists) >= 1:  # If your list is empty then you don't need to do anything.
    resultlist = [lists[0]]  # Add the first item to your resultset
    if len(lists) > 1:  # If there is only one list in your list then you don't need to do anything.
      for l in lists[1:]:  # Loop through lists starting at list 1
        listset = set(l)  # Turn your list into a set
        merged = False  # Trigger
        for index in range(len(resultlist)):  # Use indexes of the list for speed.
          rset = set(resultlist[index])  # Get a list from your resultset as a set
          if len(listset & rset) != 0:  # If listset and rset share a value, the intersection is non-empty
            resultlist[index] = list(listset | rset)  # Update the resultlist with the union of listset and rset
            merged = True  # Turn trigger to True
            break  # Because you found a match there is no need to continue the for loop.
        if not merged:  # If there was no match then add the list to the resultset, so it doesn't get left out.
          resultlist.append(l)
  return resultlist
##
# @Rumple Stiltskin
##
def stiltskin(l):
  l = deepcopy(l)
  hashdict = defaultdict(int)
  def hashit(x, y):
    for i in y: x[i] += 1
    return x
  def merge(x, y):
    sums = sum([hashdict[i] for i in y])
    if sums > len(y):
      x[0] = x[0].union(y)
    else:
      x[1] = x[1].union(y)
    return x
  hashdict = reduce(hashit, l, hashdict)
  sets = reduce(merge, l, [set(), set()])
  return list(sets)
##
# @Asterisk
##
def asterisk(l):
  l = deepcopy(l)
  results = {}
  for sm in ['min', 'max']:
    sort_method = min if sm == 'min' else max
    l = sorted(l, key=lambda x: sort_method(x))
    queue = deque(l)
    grouped = []
    while len(queue) >= 2:
      l1 = queue.popleft()
      l2 = queue.popleft()
      s1 = set(l1)
      s2 = set(l2)
      if s1 & s2:
        queue.appendleft(s1 | s2)
      else:
        grouped.append(s1)
        queue.appendleft(s2)
    if queue:
      grouped.append(queue.pop())
    results[sm] = grouped
  if len(results['min']) < len(results['max']):
    return results['min']
  return results['max']
##
# Validate no more clusters can be merged
##
def validate(output, L):
  # validate all sublists are maximally merged
  d = defaultdict(list)
  for idx, i in enumerate(output):
    for j in i:
      d[j].append(i)
  if any([len(i) > 1 for i in d.values()]):
    return 'not maximally merged'
  # validate all items in L are accounted for
  all_items = set(chain.from_iterable(L))
  accounted_items = set(chain.from_iterable(output))
  if all_items != accounted_items:
    return 'missing items'
  # validate results are good
  return 'true'
##
# Timers
##
def time(func, L):
  start = datetime.datetime.now()
  result = func(L)
  delta = datetime.datetime.now() - start
  return result, delta
##
# Function runner
##
def run_func(args):
  func, L, input_size = args
  results, elapsed = time(func, L)
  validation_result = validate(results, L)
  return func.__name__, input_size, elapsed, validation_result
##
# Main
##
all_results = defaultdict(lambda: defaultdict(list))
funcs = [mimomu, howard, jochen, mak, cmangla, braaksma, asterisk]

args = []
for trial in range(10):
  for s in range(10):
    input_size = 2**s
    # get some random inputs to use for all trials at this size
    L = []
    for i in range(input_size):
      sublist = []
      for j in range(randint(5, 10)):
        sublist.append(randint(0, 2**24))
      L.append(sublist)
    for i in funcs:
      args.append([i, L, input_size])

pool = Pool()
for result in pool.imap(run_func, args):
  func_name, input_size, elapsed, validation_result = result
  all_results[func_name][input_size].append({
    'time': elapsed,
    'validation': validation_result,
  })
  # show the running time for the function at this input size
  print(input_size, func_name, elapsed, validation_result)
pool.close()
pool.join()

# write the average of time trials at each size for each function
with open('times.tsv', 'w') as out:
  for func in all_results:
    validations = [i['validation'] for j in all_results[func] for i in all_results[func][j]]
    linetype = 'incorrect results' if any([i != 'true' for i in validations]) else 'correct results'
    for input_size in all_results[func]:
      all_times = [i['time'].microseconds for i in all_results[func][input_size]]
      avg_time = sum(all_times) / len(all_times)
      out.write(func + '\t' + str(input_size) + '\t' +
                str(avg_time) + '\t' + linetype + '\n')
And for plotting:
library(ggplot2)
df <- read.table('times.tsv', sep='\t')
p <- ggplot(df, aes(x=V2, y=V3, color=as.factor(V1))) +
  geom_line() +
  xlab('number of input lists') +
  ylab('runtime (ms)') +
  labs(color='') +
  scale_x_continuous(trans='log10') +
  facet_wrap(~V4, ncol=1)
ggsave('runtimes.png')
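If you only need the fastest correct approach rather than the whole benchmark harness, a minimal standalone sketch of @jochen's networkx-based method could look like this (merge_shared is just an illustrative name and the sample input is made up):

import networkx

def merge_shared(lists):
  # each sublist contributes a chain of edges; the connected components
  # of the resulting graph are the maximally merged clusters
  G = networkx.Graph()
  for part in lists:
    G.add_nodes_from(part)
    G.add_edges_from(zip(part, part[1:]))
  return [sorted(component) for component in networkx.connected_components(G)]

print(merge_shared([[1, 2], [2, 3], [4, 5]]))  # [[1, 2, 3], [4, 5]] (component order may vary)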