Странное поведение Inception_v3 - PullRequest
0 голосов
/ 15 мая 2019

Я пытаюсь создать генеративную сеть на основе предварительно обученного Inception_v3.

1) Я фиксирую все веса в модели

2) создать переменную, размер которой (2, 3, 299, 299)

3) создайте цели размером (2, 1000), чтобы я хотел, чтобы активация моего последнего слоя стала как можно ближе к нему путем оптимизации переменной. (Я не устанавливаю размер пакета 1, потому что в отличие от VGG16, Inception_v3 не принимает размер пакета = 1, но это не главное).

Следующий код должен работать, но выдает ошибку: «RuntimeError: одна из переменных, необходимых для вычисления градиента, была изменена операцией на месте».

# minimalist code with Inception_v3 that throws the error:

import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torchvision

torch.set_default_tensor_type('torch.FloatTensor')
Iv3 = torchvision.models.inception_v3(pretrained=True)
for i in Iv3.parameters():
    i.requires_grad = False

criterion = nn.CrossEntropyLoss()

x = Variable(torch.randn(2, 3, 299, 299), requires_grad=True)
target = torch.empty(2, dtype=torch.long).random_(1000)

output = Iv3(x)
loss = criterion(output[0], target)
loss.backward()

print(x.grad)

Это очень странно, потому что, если я делаю то же самое с VGG16, все работает нормально:

# minimalist working code with VGG16:

import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torchvision

# torch.cuda.empty_cache()
# vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
# torch.set_default_tensor_type('torch.cuda.FloatTensor')

torch.set_default_tensor_type('torch.FloatTensor')
vgg16 = torchvision.models.vgg16(pretrained=True)
for i in vgg16.parameters():
    i.requires_grad = False

criterion = nn.CrossEntropyLoss()

x = Variable(torch.randn(2, 3, 229, 229), requires_grad=True)
target = torch.empty(2, dtype=torch.long).random_(1000)

output = vgg16(x)
loss = criterion(output, target)
loss.backward()

print(x.grad)

Пожалуйста, помогите.

1 Ответ

1 голос
/ 15 мая 2019

Благодаря @iacolippo проблема решена. Оказывается, проблема была из-за Pytorch 1.0.0. Нет проблем с Pytorch 0.4.1. хотя.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...