быстрая операция numpy для нарезки определенных ячеек из массива nd - PullRequest
1 голос
/ 27 мая 2019

После профилирования следующей функции load data я понял, что следующие строки являются основным узким местом:

 dist_1 = dist[random_labels, :][:, random_labels]  
 dist_2 = dist[other_random_labels, :][:, other_random_labels]

, где размер dist равен 6000,6000, а случайные метки имеют длину 5000.
Я пытаюсь использовать np.take но

np.take(dist_1,[random_labels,random_labels]) == dist_1[random_labels, :][:, random_labels]

- это False. где размер np.take(dist_1,[random_labels,random_labels]) равен (2,5000)
Есть ли эффективный способ сделать это в NumPy?
редактировать: это самый близкий у меня есть:

 dist_1 = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...