Ошибка в ret = torch._C._nn.nll_loss2d (вход, цель, вес, _Reduction.get_enum (сокращение), ignore_index) - PullRequest
1 голос
/ 06 октября 2019

Размер входного изображения 512 * 512, чтобы соответствовать вводу resnet. входное изображение Я использовал

_img = Image.open(self.images[index]).convert('RGB')

в загрузчике данных. Я использовал resnet50 в качестве магистрали моей сети без fc.Производительность формы

[4,2048,16,16]

, затем использовал два (conv bn relu) и интерполированный

    def forward(self, input):
        x=self.backbone(input)
        x = self.conv1(x)
        x= self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x= self.bn2(x)
        x = self.relu(x)
        x = F.interpolate(x, size=[512,512], mode='bilinear', align_corners=True)
        return x

Часть обучения

    self.criterion=nn.CrossEntropyLoss()
    if self.args.cuda:
        image, target = image.cuda(), target.cuda()
    self.scheduler(self.optimizer, i, epoch, self.best_pred)
    self.optimizer.zero_grad()
    output = self.model(image)
    loss = self.criterion(output, target.long())
    loss.backward()

Но возникает ошибка

File "E:/python_workspace/1006/train.py", line 135, in training
loss = self.criterion(output, target.long())
File "E:\python_workspace\1006\utils\loss.py", line 28, in CrossEntropyLoss
loss = criterion(logit, target.long())
File "E:\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "E:\anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1995, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "E:\anaconda3\lib\site-packages\torch\nn\functional.py", line 1826, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at C:\w\1\s\tmp_conda_3.6_045031\conda\conda-bld\pytorch_1565412750030\work\aten\src\THNN/generic/SpatialClassNLLCriterion.c:111

image.shape is [4, 3, 512, 512],dtype is torch.float32
target.shape is [4, 512, 512],dtype is torch.float32
output.shape is [4, 3, 512, 512],dtype is torch.float32

целевое изображение Все целевые изображения имеют только три разных цвета. Поэтому я устанавливаю выход на 3 канала. И есть режим изображенияP Где могут быть проблемы с моим кодом?

1 Ответ

1 голос
/ 06 октября 2019

Судя по размерам ваших теннсоров, ваш batch_size=4. Вы пытаетесь предсказать одну из трех меток на пиксель, то есть n_classes=3.

Полученная ошибка:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.

Означает, что target.long()Вы предоставляете в свою функцию потерь значения, отрицательные или превышающие n_classes.

Проверьте, как вы читаете наземные ярлыки правды. Если это изображение типа P, вам необходимо прочитать его как таковое и не преобразовывать его в значения RGB.

PS,
Do not use align_corners=True in F.interpolate, это вносит искажения.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...