Перемещение тензоров элементов с помощью module.to () в PyTorch - PullRequest
0 голосов
/ 15 февраля 2019

Я строю Variational Autoencoder (VAE) в PyTorch и у меня проблемы с написанием кода, не зависящего от устройства.Автоэнкодер является потомком nn.Module с сетью кодировщика и декодера, которые тоже.Все веса сети могут быть перемещены с одного устройства на другое путем вызова net.to(device).

Проблема, с которой я столкнулся, заключается в уловке репараметризации:

encoding = mu + noise * sigma

Шум представляет собой тензортот же размер, что и mu и sigma и сохранен как переменная-член модуля автоэнкодера.Он инициализируется в конструкторе и пересэмплируется на каждом этапе обучения.Я делаю это таким образом, чтобы не создавать новый тензор шума на каждом шаге и не выдвигать его на нужное устройство.Кроме того, я хочу исправить шум в оценке.Вот код:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder.forward(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

Когда я сейчас перемещаю авто-кодер в GPU с помощью net.to('cuda:0'), я получаю ошибку при пересылке, потому что тензор шума не перемещается.

Я неЯ не хочу добавлять параметр устройства в конструктор, потому что тогда все еще невозможно переместить его на другое устройство позже.Я также попытался обернуть шум в nn.Parameter так, чтобы на него влиял net.to(), но это выдает ошибку от оптимизатора, поскольку шум помечен как requires_grad=False.

У любого есть решениепереместить все модули с помощью net.to()?

Ответы [ 3 ]

0 голосов
/ 19 февраля 2019

После еще нескольких проб и ошибок я нашел два метода:

  1. Использовать буферы: заменив self._train_noise = torch.randn(batch_size, embedding_size) на self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size), тензор шума добавляется в модуль в качестве буфера.Это позволяет net.to(device) влиять на это тоже.Кроме того, тензор теперь является частью state_dict.
  2. Override net.to(device): при этом шум остается за пределами state_dict.

    def to(device):
        new_self = super(VariationalGenerator, self).to(device)
        new_self._train_noise = new_self._train_noise.to(device)
        new_self._eval_noise = new_self._eval_noise.to(device)
    
        return new_self
    
0 голосов
/ 25 июля 2019

Лучшей версией второго подхода tilman151 является, вероятно, переопределение _apply, а не to.Таким образом, net.cuda(), net.float() и т. Д. Также будут работать, поскольку все они вызывают _apply, а не to (как видно из source , что проще, чем вы могли быдумаю):

def _apply(self, fn):
    super(VariationalGenerator, self)._apply(fn)
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    return self
0 голосов
/ 15 февраля 2019

Используйте это:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Теперь и для модели, и для каждого используемого тензора

net.to(device)
input = input.to(device)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...