Я представлю вам три варианта. Ваш базовый алгоритм может быть улучшен несколькими способами. Вместо добавления в разреженную матрицу, просто используйте предварительно выделенный массив, даже если вы не заполняете его полностью. Кроме того, вы можете конвертировать предметы в наборы только один раз в начале, чтобы избежать повторения работы. Итак, вы получите:
import numpy as np
def count_common_tokens(subjects):
n = len(subjects)
counts = np.zeros((n, n), dtype=np.int32)
subjects_sets = [set(subject) for subject in subjects]
for i1, subj_1 in enumerate(subjects_sets):
for i2 in range(i1 + 1, n):
subj_2 = subjects_sets[i2]
counts[i1, i2] = len(subj_1.intersection(subj_2))
return counts
Проблема имеет квадратичную сложность по своей природе. Но мы можем попытаться векторизовать его с помощью NumPy.
import numpy as np
def count_common_tokens_vec(subjects):
n = len(subjects)
# Concatenate all subjects
all_subjects = np.concatenate(subjects)
# Make subject ids from subject lengths
lens = [len(subject) for subject in subjects]
subject_ids = np.repeat(np.arange(n), lens)
# Find unique token ids
all_tokens, token_ids = np.unique(all_subjects, return_inverse=True)
# Make array where each row represents the token presents in each subject
subject_token = np.zeros((n, len(all_tokens)), dtype=np.int32)
np.add.at(subject_token, (subject_ids, token_ids), 1)
subject_token = subject_token.astype(bool)
# Logical and with itself to find number of common tokens
counts = np.count_nonzero(subject_token[:, np.newaxis] & subject_token[np.newaxis, :], axis=-1)
return counts
Это дает вам полную матрицу отсчетов (а не только верхний треугольник), но это может занять много памяти в порядке O(num_subjects x num_subjecs x num_tokens)
, так что, вероятно, это не будет хорошо работать для большой проблемы. Однако мы можем попытаться ускорить процесс с помощью Numba, если вам это действительно нужно. Это заставляет вас делать вещи немного по-другому, работая с массивами чисел, а не с наборами строк (возможно, здесь есть лучший способ сделать первую часть), но мы также можем получить желаемый результат с ним.
import numpy as np
import numba as nb
def count_common_tokens_nb(subjects):
n = len(subjects)
# Output array
counts = np.zeros((n, n), dtype=np.int32)
# Concatenate all subjects
all_subjects = np.concatenate(subjects)
# Find token ids for concatenated subjects
_, token_ids = np.unique(all_subjects, return_inverse=True)
# Split token ids and remove duplicates
lens = [len(subject) for subject in subjects]
subjects_sets = [np.unique(s) for s in np.split(token_ids, np.cumsum(lens)[:-1])]
# Do the counting
_count_common_tokens_nb_inner(counts, subjects_sets)
return counts
@nb.njit(parallel=True)
def _count_common_tokens_nb_inner(counts, subjects_sets):
n = len(subjects_sets)
for i1 in nb.prange(n):
subj_1 = subjects_sets[i1]
for i2 in nb.prange(i1 + 1, n):
subj_2 = subjects_sets[i2]
c = 0
for t1 in subj_1:
for t2 in subj_2:
c += int(t1 == t2)
counts[i1, i2] = c
return counts
Вот быстрый тест и небольшое сравнение производительности.
import random
import string
import numpy as np
NUM_SUBJECTS = 1000
MAX_TOKENS_SUBJECT = 20
NUM_TOKENS = 5000
MAX_LEN_TOKEN = 10
# Make random input
random.seed(0)
letters = list(string.ascii_letters)
tokens = np.array(list(set(''.join(random.choices(letters, k=random.randint(1, MAX_LEN_TOKEN)))
for _ in range(NUM_TOKENS))))
subjects = [np.array(random.choices(tokens, k=random.randint(1, MAX_TOKENS_SUBJECT)))
for _ in range(NUM_SUBJECTS)]
# Do counts
res1 = count_common_tokens(subjects)
res2 = count_common_tokens_vec(subjects)
res3 = count_common_tokens_nb(subjects)
# Check results
print(np.all(np.triu(res1, 1) == np.triu(res2, 1)))
# True
print(np.all(np.triu(res1, 1) == np.triu(res3, 1)))
# True
# Test performance
%timeit count_common_tokens(subjects)
# 196 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_vec(subjects)
# 5.09 s ± 30.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_common_tokens_nb(subjects)
# 65.2 ms ± 886 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Так что векторизованное решение не сработало, но с Numba вы значительно ускорились.