Использование словарей в Cython медленнее, чем тривиальные циклы - PullRequest
0 голосов
/ 12 февраля 2020

Вот мой код Cython:

import pickle
with open('ds.pickle', 'rb') as var:
    ds =  pickle.load(var) #list of dicts
d1 = ds[0]  #dictionary containing left shift of all possible tuples of size 4, having elems from 0 to 2^20, 2's powers
d2 = ds[1] #dictionary containing right shift of all possible tuples of size 4, having elems from 0 to 2^20, 2's powers

@cython.boundscheck(False)
@cython.wraparound(False)
def l(np.ndarray grid):
    cdef np.ndarray l1=grid.copy()
    cdef int i
    for i in range(4):
        l1[i] = d1[tuple(l1[i])]
    return l1

@cython.boundscheck(False)
@cython.wraparound(False)
def r(np.ndarray grid):
    cdef np.ndarray l1 = grid.copy()
    cdef int i
    for i in range(4):
        l1[i] = d2[tuple(l1[i])]
    return l1

@cython.boundscheck(False)
@cython.wraparound(False)
def u(np.ndarray grid):
    cdef np.ndarray l1 = grid.copy()
    cdef int i
    for i in range(4):
        l1[:,i] = d1[tuple(l1[:,i])]
    return l1

@cython.boundscheck(False)
@cython.wraparound(False)
def d(np.ndarray grid):
    cdef np.ndarray l1 = grid.copy()
    cdef int i
    for i in range(4):
        l1[:,i] = d2[tuple(l1[:,i])]
    return l1

@cython.boundscheck(False)
@cython.wraparound(False)
def c(np.ndarray grid, int move):
    if move == 2: return l(grid)
    if move == 0: return u(grid)
    if move == 1: return d(grid)
    if move == 3: return r(grid)

Это код для перемещения игровой сетки игры 2048 влево-вправо вверх. Создание случайных тайлов выполняется позже.

Вот такой тривиальный код, но он работает быстрее, чем приведенный выше код, даже если он вычисляет решение на go и не использует расширенный поиск, как предложения словаря.

@cython.boundscheck(False)
def left(np.ndarray grid):
    #assumption: grid is 4 x 4 numpy matrix 
    cdef np.ndarray l = grid.copy()
    cdef int j, i, p, merged;
    cdef long t;
    cdef list res;
    for j in range(4):
        res = [];
        merged = 0
        for i in range(4):
            t = l[j][-i-1]
            if t == 0: continue
            if res and t == res[-1] and merged == 0:
                res[-1]+=t
                merged = 1
            else:
                if res: merged = 0
                res+=[t]
        for p in range(4-len(res)): res = [0]+res
        l[j] = res[::-1]
        #l[j][0], l[j][1], l[j][2], l[j][3] = res[3], res[2], res[1], res[0]
    return l

@cython.boundscheck(False)
def right(np.ndarray grid):
    cdef np.ndarray l = grid.copy()
    cdef int j, i, p, merged;
    cdef long t;
    cdef list res;
    for j in range(4):
        res = []
        merged = 0
        for i in range(4):
            t = l[j][i]
            if t == 0: continue
            if res and t == res[-1] and merged == 0:
                res[-1]+=t
                merged = 1
            else:
                if res: merged = 0
                res+=[t]
        for p in range(4-len(res)): res = [0]+res
        l[j] = res
    return l

@cython.boundscheck(False)
def down(np.ndarray grid):
    cdef np.ndarray l = grid.copy()
    cdef int j, i, p, merged;
    cdef long t;
    cdef list res;
    for j in range(4):
        res = []
        merged = 0
        for i in range(4):
            t = l[i][j]
            if t == 0: continue
            if res and t == res[-1] and merged == 0:
                res[-1]+=t
                merged = 1
            else:
                if res: merged = 0
                res+=[t]
        for p in range(4-len(res)): res=[0]+res
        l[:, j] = res
    return l

@cython.boundscheck(False)
def up(np.ndarray grid):
    cdef np.ndarray l = grid.copy()
    cdef int j, i, p, merged;
    cdef long t;
    cdef list res;
    for j in range(4):
        res = []
        merged = 0
        for i in range(4):
            t = l[-i-1][j]
            if t == 0: continue
            if res and t == res[-1] and merged == 0:
                res[-1]+=t
                merged = 1
            else:
                if res: merged = 0
                res+=[t]
        for p in range(4-len(res)): res=[0]+res
        l[:, j] = res[::-1]
    return l

@cython.boundscheck(False)
@cython.wraparound(False)
def c(np.ndarray grid, int move):
    if move == 2: return left(grid)
    if move == 0: return up(grid)
    if move == 1: return down(grid)
    if move == 3: return right(grid)

Есть ли способ быстрее использовать эти решения для кеширования в Cython? или как улучшить мой первый пример кода?

...