Перестановка массива NumPy вдоль заданной оси - PullRequest
15 голосов
/ 18 февраля 2011

Учитывая следующий массив NumPy,

> a = array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

достаточно просто перетасовать один ряд,

> shuffle(a[0])
> a
array([[4, 2, 1, 3, 5],[1, 2, 3, 4, 5],[1, 2, 3, 4, 5]])

Можно ли использовать индексную нотацию для перемешивания каждой строки независимо? Или вам нужно перебирать массив. Я имел в виду что-то вроде

> numpy.shuffle(a[:])
> a
array([[4, 2, 3, 5, 1],[3, 1, 4, 5, 2],[4, 2, 1, 3, 5]]) # Not the real output

хотя это явно не работает.

Ответы [ 2 ]

19 голосов
/ 18 февраля 2011

Вы должны позвонить numpy.random.shuffle() несколько раз, потому что вы перетасовываете несколько последовательностей независимо. numpy.random.shuffle() работает с любой изменяемой последовательностью и на самом деле не является ufunc. Самый короткий и эффективный код для перестановки всех строк двумерного массива a по отдельности, вероятно, равен

map(numpy.random.shuffle, a)
2 голосов
/ 23 марта 2019

Векторизованное решение с rand+argsort уловкой

Мы могли бы генерировать уникальные индексы вдоль указанной оси и индексировать во входной массив с помощью advanced-indexing.Для генерации уникальных индексов мы бы использовали random float generation + sort трюк , что дает нам векторизованное решение.Мы также обобщили бы это, чтобы охватить универсальные n-dim массивы и вместе с универсальными axes с np.take_along_axis.Окончательная реализация будет выглядеть примерно так -

def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

Обратите внимание, что этот случайный порядок не будет на месте и возвращает перемешанную копию.

Пример выполнения -

In [33]: a
Out[33]: 
array([[18, 95, 45, 33],
       [40, 78, 31, 52],
       [75, 49, 42, 94]])

In [34]: shuffle_along_axis(a, axis=0)
Out[34]: 
array([[75, 78, 42, 94],
       [40, 49, 45, 52],
       [18, 95, 31, 33]])

In [35]: shuffle_along_axis(a, axis=1)
Out[35]: 
array([[45, 18, 33, 95],
       [31, 78, 52, 40],
       [42, 75, 94, 49]])
...