import numpy as np
import scipy.sparse as sparse
import random
randint = random.randint


def orig(W, n):
    """Sample ``n`` positions of zero entries of ``W`` by rejection sampling.

    Repeatedly draws a uniformly random ``(row, col)`` cell and keeps it when
    ``W[row, col] == 0``.  Draws are independent, so the returned list may
    contain duplicates.  This is the simple (slow) reference implementation.

    Parameters
    ----------
    W : sparse matrix
        Matrix to sample from.  It is accessed via ``shape``,
        ``count_nonzero()`` and scalar indexing ``W[r, c]``.
    n : int
        Number of positions to return; ``n <= 0`` yields ``[]``.

    Returns
    -------
    list of tuple(int, int)
        ``n`` ``(row, col)`` pairs, each addressing a zero entry of ``W``.

    Raises
    ------
    ValueError
        If ``n > 0`` but ``W`` has no zero entries.  Without this check the
        rejection loop would never terminate.
    """
    nrows, ncols = W.shape
    # count_nonzero() ignores explicitly stored zeros, matching W[r, c] == 0.
    if n > 0 and nrows * ncols - W.count_nonzero() == 0:
        raise ValueError("W has no zero entries to sample from")
    result = []
    while len(result) < n:
        # random.randint is inclusive on both ends.
        r = randint(0, nrows - 1)
        c = randint(0, ncols - 1)
        if W[r, c] == 0:
            result.append((r, c))
    return result
def alt(W, n):
    """Sample ``n`` distinct positions of zero entries of ``W``.

    Vectorised alternative to ``orig``.  It repeatedly overlays random sparse
    matrices, discards hits that land on stored nonzero entries of ``W``, and
    stops once at least ``n`` distinct free cells have been collected.  It then
    chooses ``n`` of those candidates at random without replacement, so the
    result contains no duplicates.

    Parameters
    ----------
    W : sparse matrix
        Matrix to sample from.  Any format convertible with ``tocsr`` works.
        Explicitly stored zeros count as free cells, as in ``orig``.
    n : int
        Number of distinct positions to return; ``n <= 0`` yields ``[]``.

    Returns
    -------
    list of tuple
        ``n`` distinct ``(row, col)`` pairs (NumPy integer scalars).

    Raises
    ------
    ValueError
        If ``n`` exceeds the number of zero entries of ``W``.
    """
    if n <= 0:
        return []
    nrows, ncols = W.shape
    # Binary mask of the cells that hold a genuinely nonzero value.  Stored
    # zeros are dropped so the free-cell count below matches the mask.
    occupied = W.tocsr(copy=True)
    occupied.sum_duplicates()
    occupied.eliminate_zeros()
    occupied.data[:] = 1
    n_free = nrows * ncols - occupied.nnz
    if n > n_free:
        raise ValueError(
            f"cannot sample {n} distinct zero entries; W has only {n_free}"
        )
    # Each round is expected to hit about density * n_free == n free cells.
    density = n / n_free
    candidates = sparse.csr_matrix((nrows, ncols))
    while candidates.count_nonzero() < n:
        candidates += sparse.random(nrows, ncols, density=density, format='csr')
        # Remove hits that fall on occupied cells of W.
        candidates -= candidates.multiply(occupied)
    candidates.eliminate_zeros()
    candidates = candidates.tocoo()
    # COO entries come out in row-major order.  Taking the first n would
    # favour low row indices, so pick n of them at random instead.
    pick = np.random.choice(candidates.nnz, n, replace=False)
    return list(zip(candidates.row[pick], candidates.col[pick]))
def alt_with_dupes(W, n):
    """Sample ``n`` positions of zero entries of ``W`` with replacement.

    Vectorised counterpart of ``orig``.  It overlays random 0/1 sparse
    matrices and counts how often each free cell of ``W`` is hit, discarding
    hits on stored nonzero entries.  It then repeats each candidate cell by its
    hit count and draws ``n`` cells from that pool with replacement, so the
    result may contain duplicates.

    Parameters
    ----------
    W : sparse matrix
        Matrix to sample from.  Any format convertible with ``tocsr`` works.
        Explicitly stored zeros count as free cells, as in ``orig``.
    n : int
        Number of positions to return; ``n <= 0`` yields ``[]``.  ``n`` may
        exceed the number of zero entries, because duplicates are allowed.

    Returns
    -------
    list of tuple
        ``n`` ``(row, col)`` pairs (NumPy integer scalars).

    Raises
    ------
    ValueError
        If ``n > 0`` but ``W`` has no zero entries.
    """
    if n <= 0:
        return []
    nrows, ncols = W.shape
    # Binary mask of the cells that hold a genuinely nonzero value.
    occupied = W.tocsr(copy=True)
    occupied.sum_duplicates()
    occupied.eliminate_zeros()
    occupied.data[:] = 1
    n_free = nrows * ncols - occupied.nnz
    if n_free == 0:
        raise ValueError("W has no zero entries to sample from")
    # Sampling is with replacement, so a pool of min(n, n_free) hits is
    # enough.  This also keeps density <= 1 when n > n_free.
    target = min(n, n_free)
    density = target / n_free
    hits = sparse.csr_matrix((nrows, ncols))
    while hits.data.sum() < target:
        tmp = sparse.random(nrows, ncols, density=density, format='csr')
        tmp.data[:] = 1  # count hits instead of summing random values
        hits += tmp
        # Remove hits that fall on occupied cells of W.
        hits -= hits.multiply(occupied)
    hits = hits.tocoo()
    num_repeats = hits.data.astype('int')
    r = np.repeat(hits.row, num_repeats)
    c = np.repeat(hits.col, num_repeats)
    idx = np.random.choice(len(r), n)
    return list(zip(r[idx], c[idx]))
Вот замеры производительности (бенчмарк в IPython):
W = sparse.random(1000, 50000, density=0.02, format='csr')
n = int((np.multiply(*W.shape) - W.nnz)*0.01)
In [194]: %timeit alt(W, n)
809 ms ± 261 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [195]: %timeit orig(W, n)
11.2 s ± 121 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [223]: %timeit alt_with_dupes(W, n)
986 ms ± 290 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Обратите внимание, что `alt` возвращает список без дубликатов, а `orig` и `alt_with_dupes` могут возвращать дубликаты.