У меня есть пользовательский класс для определения сети:
class PyTorchUNet(Model):
....
def set_loss(self):
    """Select the training loss according to ``self.activation_func``.

    The chosen callable is stored in ``self.loss_function`` as a
    one-element list holding the tuple ``('mask', loss_callable, 1.0)``.

    Raises:
        Exception: if ``self.activation_func`` is neither ``'softmax'``
            nor ``'sigmoid'``.
    """
    activation = self.activation_func
    if activation == 'softmax':
        # Mixed Dice + cross-entropy loss; both weights are read from the
        # 'model_params' section of the architecture config.
        model_params = self.architecture_config['model_params']
        loss_kwargs = {
            'dice_loss': multiclass_dice_loss,
            'cross_entropy_loss': nn.CrossEntropyLoss(),
            'dice_activation': 'softmax',
            'dice_weight': model_params['dice_weight'],
            'cross_entropy_weight': model_params['bce_weight'],
        }
        selected_loss = partial(mixed_dice_cross_entropy_loss, **loss_kwargs)
    elif activation == 'sigmoid':
        # NOTE(review): the author reported an error during validation when
        # this loss is selected -- confirm designed_loss hands plain-tensor
        # labels to the Lovasz code.
        selected_loss = designed_loss
    else:
        raise Exception('Only softmax and sigmoid activations are allowed')
    self.loss_function = [('mask', selected_loss, 1.0)]
def designed_loss(output, target):
    """Adapt ``lovasz_hinge`` to the ``(output, target)`` loss signature.

    ``lovasz_hinge_flat`` wraps values derived from the labels in
    ``Variable(...)`` itself, so the labels have to reach it as a plain
    tensor.  Calling ``target.long()`` on a Variable only changes the dtype
    and still returns a Variable, which is what triggers
    "Variable data has to be a tensor, but got Variable".

    Args:
        output: logits produced by the network, passed through unchanged.
        target: ground-truth mask, either a Variable or a plain tensor.

    Returns:
        The value returned by ``lovasz_hinge(output, labels)``.
    """
    # Unwrap to the underlying tensor when `target` exposes one via `.data`;
    # objects without that attribute are used as they are.
    labels = getattr(target, 'data', target).long()
    return lovasz_hinge(output, labels)
# this is just as it from github
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: when true, compute one loss per image and combine them with
               mean(); otherwise compute a single loss over the whole batch
    ignore: void class id, forwarded to flatten_binary_scores
    """
    if not per_image:
        # Whole batch flattened and scored in one go.
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))

    def image_losses():
        # Re-add the leading axis so each image is passed as a batch of one.
        for image_logits, image_labels in zip(logits, labels):
            flat_pair = flatten_binary_scores(image_logits.unsqueeze(0),
                                              image_labels.unsqueeze(0),
                                              ignore)
            yield lovasz_hinge_flat(*flat_pair)

    return mean(image_losses())
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened inputs
    logits: [P] Variable, logits at each prediction (between -inf and +inf)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # Nothing but void pixels: return a zero that still depends on
        # `logits`, so the gradients come out as 0.
        return logits.sum() * 0.

    # Map the {0, 1} labels to {-1, +1}.
    label_signs = Variable(2. * labels.float() - 1.)
    hinge_errors = 1. - logits * label_signs
    sorted_errors, order = torch.sort(hinge_errors, dim=0, descending=True)
    # Put the labels in the same order as the sorted errors.
    sorted_labels = labels[order.data]
    lovasz_weights = Variable(lovasz_grad(sorted_labels))
    return torch.dot(F.elu(sorted_errors), lovasz_weights)
def mean(l, ignore_nan=False, empty=0):
    """
    Mean of an arbitrary iterable, generators included (nanmean-like).

    l: iterable of values to average
    ignore_nan: when true, items for which np.isnan is true are skipped
    empty: value returned when there is nothing to average; the string
           'raise' makes the function raise ValueError instead
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(np.isnan, values)

    try:
        total = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty

    count = 1
    for item in values:
        # In-place add on purpose: accumulates into the first item.
        total += item
        count += 1

    # A single item is handed back untouched rather than divided by one.
    return total if count == 1 else total / count
Рабочий пример:
def mixed_dice_cross_entropy_loss(output, target, dice_weight=0.5, dice_loss=None,
                                  cross_entropy_weight=0.5, cross_entropy_loss=None, smooth=0,
                                  dice_activation='softmax'):
    """Weighted sum of a Dice loss and a cross-entropy loss.

    Channel 0 of ``output`` is left out of the Dice term, while the
    cross-entropy term receives the full ``output``.

    Args:
        output: model output, indexed here with four axes (class axis is 1).
        target: target masks, indexed here with four axes; channel ``k`` is
            the condition passed to ``where`` for class index ``k + 1``.
        dice_weight: multiplier applied to the Dice term.
        dice_loss: callable invoked as
            ``dice_loss(dice_output, dice_target, smooth, dice_activation)``;
            ``multiclass_dice_loss`` when None.
        cross_entropy_weight: multiplier applied to the cross-entropy term.
        cross_entropy_loss: callable invoked as
            ``cross_entropy_loss(output, class_index_target)``;
            ``nn.CrossEntropyLoss()`` when None.
        smooth: forwarded to ``dice_loss``.
        dice_activation: forwarded to ``dice_loss``.

    Returns:
        ``dice_weight * dice + cross_entropy_weight * cross_entropy``.
    """
    if cross_entropy_loss is None:
        cross_entropy_loss = nn.CrossEntropyLoss()
    if dice_loss is None:
        dice_loss = multiclass_dice_loss

    foreground_classes = output.size(1) - 1
    dice_output = output[:, 1:, :, :]
    dice_target = target[:, :foreground_classes, :, :].long()

    # Start from all zeros and let `where` write class index k + 1 based on
    # target channel k (presumably wherever that channel is set -- `where`
    # is defined elsewhere, TODO confirm).
    class_index_target = torch.zeros_like(target[:, 0, :, :]).long()
    for class_label in range(1, foreground_classes + 1):
        class_index_target = where(target[:, class_label - 1, :, :],
                                   class_label,
                                   class_index_target)

    dice_term = dice_weight * dice_loss(dice_output, dice_target, smooth, dice_activation)
    cross_entropy_term = cross_entropy_weight * cross_entropy_loss(output, class_index_target)
    return dice_term + cross_entropy_term
def multiclass_dice_loss(output, target, smooth=0, activation='softmax'):
    """Calculate Dice Loss for multiple class output.

    Args:
        output (torch.Tensor): Model output of shape (N x C x H x W).
        target (torch.Tensor): Per-class target masks of shape
            (N x C x H x W); channel ``c`` is compared with channel ``c``
            of the activated output. The caller's tensor is not modified.
        smooth (float, optional): Smoothing factor. Defaults to 0.
        activation (string, optional): Name of the activation function,
            softmax or sigmoid. Defaults to 'softmax'.

    Returns:
        torch.Tensor: Sum of the per-channel Dice losses divided by the
        number of output channels.

    Raises:
        NotImplementedError: If ``activation`` is neither 'softmax' nor
            'sigmoid'.
    """
    if activation == 'softmax':
        activation_nn = torch.nn.Softmax2d()
    elif activation == 'sigmoid':
        activation_nn = torch.nn.Sigmoid()
    else:
        raise NotImplementedError('only sigmoid and softmax are implemented')

    # Cast a detached local copy to float.  Reassigning `target.data`, as
    # this function used to do, changed the dtype of the caller's object as
    # a hidden side effect of computing a loss.
    target = target.detach().float()

    dice = DiceLoss(smooth=smooth)
    output = activation_nn(output)
    num_classes = output.size(1)

    loss = 0
    for class_nr in range(num_classes):
        loss += dice(output[:, class_nr, :, :], target[:, class_nr, :, :])
    return loss / num_classes
В результате при выборе ветки 'sigmoid' (функция designed_loss) на этапе валидации я продолжаю получать ошибку:
RuntimeError: Variable data has to be a tensor, but got Variable
Как решить проблему?