Копирование данных из одного тензора в другой с использованием битовой маскировки - PullRequest
0 голосов
/ 17 декабря 2018
import numpy as np
import torch
a = torch.zeros(5)
b = torch.tensor(tuple((0,1,0,1,0)),dtype=torch.uint8)
c= torch.tensor([7.,9.])
print(a[b].size())
a[b]=c
print(a)

факел. Размер ([2])
тензор ([0, 7, 0., 9., 0.])

Я естьпытается понять, как это работает.Сначала я думал, что в приведенном выше коде используется индексирование Fancy, но я понял, что копируются значения из тензоров c , соответствующие индексам, помеченным 1. Кроме того, если я не укажу dtype для b as uint8 , тогда приведенный выше код не работает.Может кто-нибудь, пожалуйста, объясните мне механизм вышеуказанного кода.

1 Ответ

0 голосов
/ 17 декабря 2018

Индексирование с помощью массивов работает так же, как в numpy и большинстве других векторизованных математических пакетов, о которых я знаю.Есть два случая:

  1. Когда b имеет тип uint8 (думаю, логическое значение, pytorch не различает bool от uint8), a[b] является1-й массив, содержащий подмножество значений a (a[i]), для которого соответствующее значение в b (b[i]) было ненулевым.Эти значения привязаны к исходному a, поэтому, если вы измените их, их соответствующие местоположения также изменятся.

  2. Альтернативный тип, который вы можете использовать для индексации, - это массив * 1021.*, в этом случае a[b] создает массив формы (*b.shape, *a.shape[1:]).Его структура такая, как если бы каждый элемент b (b[i]) был заменен на a[i].Другими словами, вы создаете новый массив, указывая, из каких индексов a следует извлекать данные.Опять же, значения привязаны к исходному a, поэтому если вы измените a[b], значения a[b[i]] для каждого i будут меняться.Пример использования показан в этом вопросе.

Эти два режима описаны для numpy в индексации целочисленного массива и логического массиваиндексирование , где для последнего вы должны помнить, что pytorch использует uint8 вместо bool.

Кроме того, если ваша цель - скопировать данные из одного тензора в другой, у вас естьпомнить, что такая операция, как a[ixs] = b[ixs], является операцией на месте (a изменяется на месте), что не очень хорошо работает с autograd.Если вы хотите сделать маскировку вне места, используйте torch.where.Пример использования показан в этом ответе.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...