Как сделать маскировку в PyTorch / Numpy с разными размерами? - PullRequest
1 голос
/ 05 апреля 2020

У меня есть mask с размером torch.Size([20, 1, 199]) и тензор, reconstruct_output и inputs с размером torch.Size([20, 1, 161, 199]).

Я хочу установить reconstruct_output в inputs, где mask равно 0. Я пытался:

reconstruct_output[mask == 0] = inputs[mask == 0]

Но я получаю сообщение об ошибке:

IndexError: The shape of the mask [20, 1, 199] at index 2 does not match the shape of the indexed tensor [20, 1, 161, 199] at index 2

1 Ответ

2 голосов
/ 05 апреля 2020

Мы можем использовать advanced indexing здесь. Чтобы получить индексные массивы, которые мы хотим использовать для индексирования reconstruct_output и inputs, нам нужны индексы вдоль его осей, где m==0. Для этого мы можем использовать np.where и использовать полученные индексы для обновления reconstruct_output как:

m = mask == 0
i, _, l = np.where(m)
reconstruct_output[i, ..., l] = inputs[i, ..., l]

Вот небольшой пример, с которым я проверял:

mask = np.random.randint(0,3, (2, 1, 4))
reconstruct_output = np.random.randint(0,10, (2, 1, 3, 4))
inputs = np.random.randint(0,10, (2, 1, 3, 4))

Предоставление, например:

print(reconstruct_output)

array([[[[8, 9, 7, 2],
         [5, 4, 6, 1],
         [1, 4, 0, 3]]],


       [[[4, 3, 3, 4],
         [0, 9, 9, 7],
         [3, 4, 9, 3]]]])

print(inputs)

array([[[[7, 3, 9, 8],
         [3, 1, 0, 8],
         [0, 5, 4, 8]]],


       [[[3, 7, 5, 8],
         [2, 5, 3, 8],
         [3, 6, 7, 5]]]])

И mask:

print(mask)

array([[[0, 1, 2, 1]],

       [[1, 0, 1, 0]]])

Используя np.where, чтобы найти индексы, где есть нули в mask мы получим:

m = mask == 0
i, _, l = np.where(m)

i
# array([0, 1, 1])

l
# array([0, 1, 3])

Следовательно, мы заменим 0-й столбец из первого 2D-массива и 1-й и 3-й из второго 2D-массива.

Теперь мы можем используйте эти массивы для замены вдоль соответствующих осей индексации как:

reconstruct_output[i, ..., l] = inputs[i, ..., l]

Получение:

reconstruct_output

array([[[[7, 9, 7, 2],
         [3, 4, 6, 1],
         [0, 4, 0, 3]]],


       [[[4, 7, 3, 8],
         [0, 5, 9, 8],
         [3, 6, 9, 5]]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...