После профилирования следующей функции load data
я понял, что следующие строки являются основным узким местом:
dist_1 = dist[random_labels, :][:, random_labels]
dist_2 = dist[other_random_labels, :][:, other_random_labels]
, где размер dist
равен 6000,6000
, а случайные метки имеют длину 5000
.
Я пытаюсь использовать np.take
но
np.take(dist_1,[random_labels,random_labels]) == dist_1[random_labels, :][:, random_labels]
- это False
.
где размер np.take(dist_1,[random_labels,random_labels])
равен (2,5000)
Есть ли эффективный способ сделать это в NumPy?
редактировать: это самый близкий у меня есть:
dist_1 = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)