Numpy - фильтрация массива на основе индексов и соответствующей оси - PullRequest
0 голосов
/ 31 октября 2019

У меня есть целочисленный массив ids и массив с плавающей точкой distances_per_batch:

BATCH_SIZE = 500
ARRAY_SIZE = 10000

ids = np.arange(ARRAY_SIZE) # Shape = ARRAY_SIZE,
distances_per_batch = np.random.rand(BATCH_SIZE, ARRAY_SIZE) # Shape = BATCH_SIZE, ARRAY_SIZE

Я пытаюсь получить идентификаторы, где их расстояние больше 0,9:

ids_expanded = np.repeat(np.expand_dims(ids, axis=0), BATCH_SIZE, axis=0) # Shape = BATCH_SIZE, ARRAY_SIZE (Not sure if this is even right to use since it takes a while for larger BATCH_SIZE & ARRAY_SIZE and seems to create a new array
selected_ids = ids_expanded[distances_per_batch > 0.9]

Я ожидаю, что selected_ids будет иметь двумерную форму (500,?), чтобы получить идентификаторы, которые имеют расстояние больше 0,9 для каждой записи в пакете (всего 500 записей), но конечный результат автоматически изменяетсяв одномерный массив, и я не могу решить, какой выбранный идентификатор принадлежит к какой из 500 записей ...

Как быстро и правильно получить желаемые результаты (не просматривая каждую записьодин за другим и используя более быстрые методы Numpy)? Я даже не уверен, что расширение размеров и повторение массива - правильный путь, поскольку для увеличения BATCH_SIZE & ARRAY_SIZE требуется некоторое время, и создается впечатление, что он создает новый массив.

1 Ответ

1 голос
/ 31 октября 2019
np.where(distances_per_batch > 0.9) 

возвращает отдельные массивы индексов строк и столбцов. Чтобы собрать их

np.transpose(np.where(distances_per_batch > 0.9))

с любыми случайными данными, он возвращает

array([[   0,    0],
       [   0,   31],
       [   0,   33],
       ..., 
       [ 499, 9988],
       [ 499, 9993],
       [ 499, 9995]], dtype=int32)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...