Нарезка четырехмерного тензора с трехмерным тензорным индексом в PyTorch - PullRequest
1 голос
/ 20 октября 2019

У меня есть 4D тензор (который является стеком из трех партий изображений 56x56, где каждая партия имеет 16 изображений) с размером [16,3, 56, 56] . Моя цель состоит в том, чтобы выбрать правильный из этих трех пакетов (с моей индексной картой, которая имеет размер [16, 56, 56] ) для каждого пикселя и получить изображения, которые я хочу.

Теперь я хочу выбрать определенные пакеты изображений внутри этих трех пакетов, со значением, которое имеет значения, такие как

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

Так что для 0, значение будетвыбран из первого пакета, где 1 и 2 означают, что я хочу выбрать значения из второго и третьего пакетов.

Вот некоторые из визуализаций индексов, каждый цвет обозначает другой пакет.

enter image description here

enter image description here

Я попытался транспонировать 4-мерный тензор, чтобы он соответствовал размерам моих индексов,но это не сработало. Все, что он делает, это дает мне копию измерений, которые я пытался выбрать. Означает

tposed = torch.transpose(fourD, 0,1) print(indices.size(),
outs.size(), tposed[:, indices].size())

вывод

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

, в то время как мне нужна форма

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])

и в качестве примера, если я пытаюсь выбрать правильные значения только дляпервое изображение в пакете с

fourD[0,indices].size()

Я получаю форму, подобную

torch.Size([16, 56, 56, 56, 56])

Не говоря уже о том, что у меня возникает ошибка нехватки памяти, когда я пробую это на всем тензоре.

Я ценю любую помощь для , использующего эти индексы для выбора одного из этих трех пакетов для каждого пикселя в моих изображениях .

Примечание:

Iпопробовал опцию

outs[indices[:,None,:,:]].size()

и возвращает

torch.Size([16, 1, 56, 56, 3, 56, 56])

Редактировать: torch.take не очень помогает, так как он обрабатывает входной тензор как одномерный массив.

1 Ответ

1 голос
/ 21 октября 2019

Оказывается, в PyTorch есть функция, которая обладает той функцией, которую я искал.

torch.gather(fourD, 1, indices.unsqueeze(1)) 

сделал работу.

Здесь - прекрасное объяснение того, что делает сбор.

...