Получение последних элементов по оси в массиве numpy - PullRequest
0 голосов
/ 01 августа 2020

Мне нужна функция для получения последних элементов по оси в массиве numpy.

Например, если у меня есть массив,

a = np.array([1, 2, 3])

Функция должна работать как

get_last_elements(a, axis=0)
>>> [3]
get_last_elements(a, axis=1)
>>> [1, 2, 3]

Эта функция должна работать и для многомерного массива:

b = np.array([[1, 2],
              [3, 4]])

get_last_elements(b, axis=0)
>>> [[2],
     [4]]
get_last_elements(b, axis=1)
>>> [3, 4]

Есть ли у кого-нибудь хорошая идея для этого?

1 Ответ

1 голос
/ 02 августа 2020

Вы можете использовать np.take, чтобы получить это:

def get_last_elements(a, axis=0):
  shape = list(a.shape)
  shape[axis] = 1
  return np.take(a,-1,axis=axis).reshape(tuple(shape))

вывод:

print(get_last_elements(b, axis=0))
[[3 4]]

print(get_last_elements(b, axis=1))
[[2]
 [4]]
...