Я строю Variational Autoencoder (VAE) в PyTorch и у меня проблемы с написанием кода, не зависящего от устройства.Автоэнкодер является потомком nn.Module
с сетью кодировщика и декодера, которые тоже.Все веса сети могут быть перемещены с одного устройства на другое путем вызова net.to(device)
.
Проблема, с которой я столкнулся, заключается в уловке репараметризации:
encoding = mu + noise * sigma
Шум представляет собой тензортот же размер, что и mu
и sigma
и сохранен как переменная-член модуля автоэнкодера.Он инициализируется в конструкторе и пересэмплируется на каждом этапе обучения.Я делаю это таким образом, чтобы не создавать новый тензор шума на каждом шаге и не выдвигать его на нужное устройство.Кроме того, я хочу исправить шум в оценке.Вот код:
class VariationalGenerator(nn.Module):
def __init__(self, input_nc, output_nc):
super(VariationalGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
embedding_size = 128
self._train_noise = torch.randn(batch_size, embedding_size)
self._eval_noise = torch.randn(1, embedding_size)
self.noise = self._train_noise
# Create encoder
self.encoder = Encoder(input_nc, embedding_size)
# Create decoder
self.decoder = Decoder(output_nc, embedding_size)
def train(self, mode=True):
super(VariationalGenerator, self).train(mode)
self.noise = self._train_noise
def eval(self):
super(VariationalGenerator, self).eval()
self.noise = self._eval_noise
def forward(self, inputs):
# Calculate parameters of embedding space
mu, log_sigma = self.encoder.forward(inputs)
# Resample noise if training
if self.training:
self.noise.normal_()
# Reparametrize noise to embedding space
inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
# Decode to image
inputs = self.decoder(inputs)
return inputs, mu, log_sigma
Когда я сейчас перемещаю авто-кодер в GPU с помощью net.to('cuda:0')
, я получаю ошибку при пересылке, потому что тензор шума не перемещается.
Я неЯ не хочу добавлять параметр устройства в конструктор, потому что тогда все еще невозможно переместить его на другое устройство позже.Я также попытался обернуть шум в nn.Parameter
так, чтобы на него влиял net.to()
, но это выдает ошибку от оптимизатора, поскольку шум помечен как requires_grad=False
.
У любого есть решениепереместить все модули с помощью net.to()
?