Производительность сортировки структурированных массивов (numpy) - PullRequest
1 голос
/ 03 апреля 2019

У меня есть массив с несколькими полями, которые я хочу отсортировать по двум из них.Одно из этих полей является двоичным, например:

size = 100000
data = np.empty(
            shape=2 * size,
            dtype=[('class', int),
                   ('value', int),]
)

data['class'][:size] = 0
data['value'][:size] = (np.random.normal(size=size) * 10).astype(int)
data['class'][size:] = 1
data['value'][size:] = (np.random.normal(size=size, loc=0.5) * 10).astype(int)

np.random.shuffle(data)

Мне нужно отсортировать результат по value, и для тех же значений class=0 должно идти первым.Делая это так (a) :

idx = np.argsort(data, order=['value', 'class'])
data_sorted = data[idx]

кажется на порядок медленнее по сравнению с сортировкой data['value'].Есть ли способ улучшить скорость, учитывая, что есть только два класса?

Экспериментируя случайно, я заметил, что такой подход (b) :

idx = np.argsort(data['value'])
data_sorted = data[idx]
idx = np.argsort(data_sorted, order=['value', 'class'], kind='mergesort')
data_sorted = data_sorted[idx]

занимает на 20% меньше времени, чем (a) .Изменение типов данных поля, по-видимому, также имеет некоторый эффект - плавающие вместо целых, кажется, немного быстрее.

1 Ответ

1 голос
/ 03 апреля 2019

Самый простой способ сделать это - использовать параметр order, равный sort

sort(data, order=['value', 'class'])

Однако для запуска на моем компьютере требуется 121 мс, в то время как для data['class'] и data['value'] требуется всего 2,44 и 5,06 мс соответственно. Интересно, что sort(data, order='class') снова занимает 135 мс, что указывает на проблему с сортировкой структурированных массивов.

Итак, вы выбрали сортировку каждого поля с использованием argsort, а затем индексирование окончательного массива, похоже, на правильном пути. Однако вам нужно отсортировать каждое поле отдельно,

idx=argsort(data['class'])
data_sorted = data[idx][argsort(data['value'][idx], kind='stable')]

Это выполняется за 43,9 мс. Вы можете получить очень небольшое ускорение, удалив один временный массив из индексации

idx = argsort(data['class'])
tmp = data[idx]
data_sorted = tmp[argsort(tmp['value'], kind='stable')]

Который работает за 40,8 мс. Не здорово, но это обходной путь, если производительность критична.

Кажется, это известная проблема: очень медленная сортировка структурированных и записанных массивов

Редактировать Исходный код для сравнений, используемых в сортировке, можно увидеть по https://github.com/numpy/numpy/blob/dea85807c258ded3f75528cce2a444468de93bc1/numpy/core/src/multiarray/arraytypes.c.src. Числовые типы намного, намного проще. Тем не менее, такая большая разница в производительности удивительна.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...