когда я тренирую свою сеть, получена ошибка, ожидаемый размер цели (4, 224), факел. Размер ([4, 224, 224]) - PullRequest
0 голосов
/ 22 марта 2020

В моем train.py

criteon = nn.CrossEntropyLoss()
loss = criteon(binary_output_c1,labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

оба двоичных_произведения_c1, размер меток равен [4, 224,224], 4 означает размер пакета, 224 означает h и w. и он получил такую ​​ошибку 'size to [4,256,224,224], где 256 - количество классов. код здесь

model.train()
outputs = model(imgs)   # output  B * C * H *W
output_c1 = outputs[:,1,:,:] # 2 channels ,I choose the second channel
Rounding_output_c1 = torch.round(output_c1)
labelss =  torch.stack([(labels == i).long() for i in range(256)])
labelss = labelss.permute(1,0,2,3)
Rounding_output_c11 = torch.stack([(Rounding_output_c1 == i).float() for i in range(256)])
Rounding_output_c11 = Rounding_output_c11.permute(1,0,2,3)
loss = criteon(Rounding_output_c11,labelss)
optimizer.zero_grad()
loss.backward()

Ошибка тоже получается

Traceback (most recent call last):
  File "F:/experiment_code/U-net/train_2.py", line 76, in <module>
    loss = criteon(Rounding_output_c11,labelss)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 942, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "D:\Anaconda3\lib\site-packages\torch\nn\functional.py", line 2056, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1873, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [4, 256, 224, 224]

1 Ответ

0 голосов
/ 22 марта 2020

Если вы используете nn.CrossEntropyLoss, тогда ваш прогноз должен иметь два канала: один для прогнозирования 0 и другой для прогнозирования 1. Это немного избыточно, но потеря предполагает, что прогноз будет иметь #channels == # метки.

В качестве альтернативы, вы можете согласовать прогноз перед передачей к потере:

loss = criteon(torch.cat((-binary_output_c1[:, None, ...], binary_output_c1[:, None,...]), dim=1),labels)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...