Как сгенерировать смежные валюты с NumPy - PullRequest
1 голос
/ 06 апреля 2019

Поэтому я пытаюсь сгенерировать список возможных смежных движений в трехмерном массиве (предпочтительно n-мерный).

То, что у меня есть, работает так, как должно, но мне было интересно, есть ли более тупой способ сделать это.

def adjacents(loc, bounds):
    adj = []
    bounds = np.array(bounds) - 1

    if loc[0] > 0:
        adj.append((-1, 0, 0))
    if loc[1] > 0:
        adj.append((0, -1, 0))
    if loc[2] > 0:
        adj.append((0, 0, -1))

    if loc[0] < bounds[0]:
        adj.append((1, 0, 0))
    if loc[1] < bounds[1]:
        adj.append((0, 1, 0))
    if loc[2] < bounds[2]:
        adj.append((0, 0, 1))

    return np.array(adj)

Вот несколько примеров выходных данных:

adjacents((0, 0, 0), (10, 10, 10)) 

= [[1 0 0]
   [0 1 0]
   [0 0 1]]

adjacents((9, 9, 9), (10, 10, 10))

= [[-1  0  0]
   [ 0 -1  0]
   [ 0  0 -1]]

adjacents((5, 5, 5), (10, 10, 10))

= [[-1  0  0]
   [ 0 -1  0]
   [ 0  0 -1]
   [ 1  0  0]
   [ 0  1  0]
   [ 0  0  1]]

1 Ответ

2 голосов
/ 06 апреля 2019

Вот альтернатива, которая векторизована и использует постоянный, предварительно заполненный массив:

# all possible moves
_moves = np.array([
        [-1, 0, 0],
        [ 0,-1, 0],
        [ 0, 0,-1],
        [ 1, 0, 0],
        [ 0, 1, 0],
        [ 0, 0, 1]])

def adjacents(loc, bounds):
    loc = np.asarray(loc)
    bounds = np.asarray(bounds)
    mask = np.concatenate((loc > 0, loc < bounds - 1))
    return _moves[mask]

Используется asarray() вместо array(), потому что это позволяет избежать копирования, если входные данные уже являются массивом. Затем mask строится как массив из шести bools, соответствующих исходным шести if условиям. Наконец, возвращаются соответствующие строки постоянных данных _moves.

А как насчет производительности?

Векторизованный подход выше, хотя и понравится некоторым, на самом деле работает только вдвое быстрее, чем оригинал. Если вам нужна производительность, лучшее простое изменение, которое вы можете сделать, это удалить строку bounds = np.array(bounds) - 1 и вычесть 1 внутри каждого из трех последних условий if. Это дает вам двукратное ускорение (потому что это позволяет избежать создания ненужного массива).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...