Есть ли другой способ сортировки массива 3D numpy по значениям столбцов? - PullRequest
1 голос
/ 27 апреля 2020

Я написал код python для сортировки набора из четырех таблиц 3x3 по значению их первых столбцов. Есть ли более простой способ сделать это с меньшим количеством кода и, возможно, более эффективным? Вот мой код:

import numpy as np 

np.random.seed(4)

a = np.random.randint(10, size=(4, 3, 3))
ind = a[:,:,0].argsort()
ind = np.stack(a.shape[2]*[ind], axis=1)
b = np.take_along_axis(a.transpose(0, 2, 1), ind, axis=2).transpose(0, 2, 1)

print(a)
print("----------------")
print(b)
[[[7 5 1]
  [8 7 8]
  [2 9 7]]

 [[7 7 9]
  [8 4 2]
  [6 4 3]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[5 8 1]
  [2 7 0]
  [8 3 1]]]
----------------
[[[2 9 7]
  [7 5 1]
  [8 7 8]]

 [[6 4 3]
  [7 7 9]
  [8 4 2]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[2 7 0]
  [5 8 1]
  [8 3 1]]]

1 Ответ

0 голосов
/ 27 апреля 2020

У вас есть хорошее решение для этого. Вот более короткий (и, вероятно, более быстрый) вариант:

b = np.einsum('iijk->ijk', a[:,a[:,:,0].argsort()])

einsum в основном делает то, что вы пытаетесь достичь с помощью индексации. Он принимает i-th элемент i-th элемент a[:,a[:,:,0].argsort().

b:

[[[2 9 7]
  [7 5 1]
  [8 7 8]]

 [[6 4 3]
  [7 7 9]
  [8 4 2]]

 [[0 7 5]
  [5 9 6]
  [6 8 2]]

 [[2 7 0]
  [5 8 1]
  [8 3 1]]]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...