Самый быстрый способ сортировки в Python (без Cython) - PullRequest
0 голосов
/ 07 мая 2018

У меня проблема с сортировкой очень большого массива (форма - 7900000X4X4) с пользовательской функцией. Я использовал sorted, но сортировка заняла более 1 часа. Мой код был примерно таким.

def compare(x,y):
    print('DD '+str(x[0]))
    if(np.array_equal(x[1],y[1])==True):
        return -1
    a = x[1].flatten()
    b = y[1].flatten()
    idx = np.where( (a>b) != (a<b) )[0][0]
    if a[idx]<0 and b[idx]>=0:
        return 0
    elif b[idx]<0 and a[idx]>=0:
        return 1
    elif a[idx]<0 and b[idx]<0:
        if a[idx]>b[idx]:
            return 0
        elif a[idx]<b[idx]:
            return 1
    elif a[idx]<b[idx]:
        return 1
    else:
        return 0
def cmp_to_key(mycmp):
    class K:
        def __init__(self, obj, *args):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj)
    return K
tblocks = sorted(tblocks.items(),key=cmp_to_key(compare))

Это сработало, но я хочу, чтобы оно завершилось за несколько секунд. Я не думаю, что какая-либо прямая реализация в Python может дать мне необходимую производительность, поэтому я попробовал Cython. Вот мой код на Cython, который довольно прост.

cdef int[:,:] arrr
cdef int size

cdef bool compare(int a,int b):
    global arrr,size
    cdef int[:] x = arrr[a]
    cdef int[:] y = arrr[b]
    cdef int i,j
    i = 0
    j = 0
    while(i<size):
        if((j==size-1)or(y[j]<x[i])):
            return 0
        elif(x[i]<y[j]):
            return 1
        i+=1
        j+=1
    return (j!=size-1)

def sorted(np.ndarray boxes,int total_blocks,int s):
    global arrr,size
    cdef int i
    cdef vector[int] index = xrange(total_blocks)
    arrr = boxes
    size = s
    sort(index.begin(),index.end(),compare)
    return index

Этот код в Cython занял 33 секунды! Cython - это решение, но я ищу альтернативные решения, которые могут работать непосредственно на Python. Например, нумба. Я попробовал Numba, но я не получил удовлетворительных результатов. Пожалуйста, помогите!

Ответы [ 2 ]

0 голосов
/ 07 мая 2018

Если я правильно понимаю ваш код, то порядок, который вы имеете в виду, является стандартным, только то, что он начинается с 0, охватывает +/-infinity и достигает максимума -0. Кроме того, у нас есть простой лексикографический порядок слева направо.

Теперь, если ваш массив dtype является целым числом, обратите внимание на следующее: Из-за представления представления отрицательных значений приведение типов к unsigned int делает ваш заказ стандартным. Вдобавок к этому, если мы используем кодирование с прямым порядком байтов, эффективное лексикографическое упорядочение может быть достигнуто путем приведения к виду void dtype.

Приведенный ниже код показывает, что на примере 10000x4x4 этот метод дает тот же результат, что и ваш код Python.

Он также тестирует его на примере 7,900,000x4x4 (используя массив, а не dict). На моем скромном ноутбуке этот метод занимает 8 секунд.

import numpy as np

def compare(x, y):
#    print('DD '+str(x[0]))
    if(np.array_equal(x[1],y[1])==True):
        return -1
    a = x[1].flatten()
    b = y[1].flatten()
    idx = np.where( (a>b) != (a<b) )[0][0]
    if a[idx]<0 and b[idx]>=0:
        return 0
    elif b[idx]<0 and a[idx]>=0:
        return 1
    elif a[idx]<0 and b[idx]<0:
        if a[idx]>b[idx]:
            return 0
        elif a[idx]<b[idx]:
            return 1
    elif a[idx]<b[idx]:
        return 1
    else:
        return 0
def cmp_to_key(mycmp):
    class K:
        def __init__(self, obj, *args):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj)
    return K

def custom_sort(a):
    assert a.dtype==np.int64
    b = a.astype('>i8', copy=False)
    return b.view(f'V{a.dtype.itemsize * a.shape[1]}').ravel().argsort()

tblocks = np.random.randint(-9,10, (10000, 4, 4))
tblocks = dict(enumerate(tblocks))

tblocks_s = sorted(tblocks.items(),key=cmp_to_key(compare))

tblocksa = np.array(list(tblocks.values()))
tblocksa = tblocksa.reshape(tblocksa.shape[0], -1)
order = custom_sort(tblocksa)
tblocks_s2 = list(tblocks.items())
tblocks_s2 = [tblocks_s2[o] for o in order]

print(tblocks_s == tblocks_s2)

from timeit import timeit

data = np.random.randint(-9_999, 10_000, (7_900_000, 4, 4))

print(timeit(lambda: data[custom_sort(data.reshape(data.shape[0], -1))],
             number=5) / 5)

Пример вывода:

True
7.8328493310138585
0 голосов
/ 07 мая 2018

Трудно дать ответ без рабочего примера. Я предполагаю, что arrr в вашем коде Cython был 2D-массивом, и я предполагаю, что размер был size=arrr.shape[0]

Реализация Numba

import numpy as np
import numba as nb
from numba.targets import quicksort


def custom_sorting(compare_fkt):
  index_arange=np.arange(size)

  quicksort_func=quicksort.make_jit_quicksort(lt=compare_fkt,is_argsort=False)
  jit_sort_func=nb.njit(quicksort_func.run_quicksort)
  index=jit_sort_func(index_arange)

  return index

def compare(a,b):
    x = arrr[a]
    y = arrr[b]
    i = 0
    j = 0
    while(i<size):
        if((j==size-1)or(y[j]<x[i])):
            return False
        elif(x[i]<y[j]):
            return True
        i+=1
        j+=1
    return (j!=size-1)


arrr=np.random.randint(-9,10,(7900000,8))
size=arrr.shape[0]

index=custom_sorting(compare)

Это дает 3,85 для сгенерированных тестовых данных. Но скорость алгоритма сортировки сильно зависит от данных ....

Простой пример

import numpy as np
import numba as nb
from numba.targets import quicksort

#simple reverse sort
def compare(a,b):
  return a > b

#create some test data
arrr=np.array(np.random.rand(7900000)*10000,dtype=np.int32)
#we can pass the comparison function
quicksort_func=quicksort.make_jit_quicksort(lt=compare,is_argsort=True)
#compile the sorting function
jit_sort_func=nb.njit(quicksort_func.run_quicksort)
#get the result
ind_sorted=jit_sort_func(arrr)

Эта реализация примерно на 35% медленнее, чем np.argsort, но это также распространено при использовании np.argsort в скомпилированном коде.

...