Размер входного изображения 512 * 512, чтобы соответствовать вводу resnet. входное изображение Я использовал
_img = Image.open(self.images[index]).convert('RGB')
в загрузчике данных. Я использовал resnet50 в качестве магистрали моей сети без fc.Производительность формы
[4,2048,16,16]
, затем использовал два (conv bn relu) и интерполированный
def forward(self, input):
x=self.backbone(input)
x = self.conv1(x)
x= self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x= self.bn2(x)
x = self.relu(x)
x = F.interpolate(x, size=[512,512], mode='bilinear', align_corners=True)
return x
Часть обучения
self.criterion=nn.CrossEntropyLoss()
if self.args.cuda:
image, target = image.cuda(), target.cuda()
self.scheduler(self.optimizer, i, epoch, self.best_pred)
self.optimizer.zero_grad()
output = self.model(image)
loss = self.criterion(output, target.long())
loss.backward()
Но возникает ошибка
File "E:/python_workspace/1006/train.py", line 135, in training
loss = self.criterion(output, target.long())
File "E:\python_workspace\1006\utils\loss.py", line 28, in CrossEntropyLoss
loss = criterion(logit, target.long())
File "E:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "E:\anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1995, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1826, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at C:\w\1\s\tmp_conda_3.6_045031\conda\conda-bld\pytorch_1565412750030\work\aten\src\THNN/generic/SpatialClassNLLCriterion.c:111
image.shape is [4, 3, 512, 512],dtype is torch.float32
target.shape is [4, 512, 512],dtype is torch.float32
output.shape is [4, 3, 512, 512],dtype is torch.float32
целевое изображение Все целевые изображения имеют только три разных цвета. Поэтому я устанавливаю выход на 3 канала. И есть режим изображенияP Где могут быть проблемы с моим кодом?