Частичные срезы в pytorch / numpy с произвольным и переменным числом измерений - PullRequest
1 голос
/ 11 июня 2019

Учитывая двумерный тензор в numpy (или в pytorch), я могу частично срезать по всем измерениям одновременно следующим образом:

>>> import numpy as np
>>> a = np.arange(2*3).reshape(2,3)
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

Как я могу получить один и тот же шаблон нарезки независимо от числаразмеров в тензоре, если я не знаю количество измерений во время реализации?(т.е. я хочу a[1:], если a имеет только одно измерение, a[1:,1:] для двух измерений, a[1:,1:,1:] для трех измерений и т. д.)

Было бы хорошо, если бы я мог сделать это водна строка кода, подобная следующей, но это неверно:

a[(1:,) * len(a.shape)]  # SyntaxError: invalid syntax

Меня особенно интересует решение, которое работает для тензоров pytorch (просто замените torch на numpy выше, и пример такой же),но я полагаю, что это вероятно и лучше всего, если решение будет работать как для numpy, так и для pytorch.

1 Ответ

1 голос
/ 11 июня 2019

Ответ: Создание кортежа из объектов slice делает свое дело:

a[(slice(1,None),) * len(a.shape)]

Пояснение: slice - это встроенный класс Python (не привязанный к numpy или pytorch), который предоставляет альтернативу нижнему индексу для описания слайсов. Ответ на другой вопрос предлагает использовать это как способ хранения информации среза в переменных Python. Глоссарий Python указывает, что

Обозначение в скобках (нижний индекс) использует slice объекты внутри.

Так как методы __getitem__ для numpy ndarrays и тензоров pytorch поддерживают многомерное индексирование с помощью срезов, они также должны поддерживать многомерное индексирование с помощью объектов срезов, и поэтому мы можем сделать кортеж из этих ломтиков правильной длины.

Кстати, вы можете увидеть, как python использует объекты срезов, создав фиктивный класс следующим образом, а затем выполнить нарезку на него:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  #  (slice(1, None, None), slice(1, None, None))


Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...