Pytorch: правильный способ использования пользовательских весовых карт в архитектуре Unet - PullRequest
2 голосов
/ 14 октября 2019

В архитектуре u-net существует известная хитрость, заключающаяся в использовании пользовательских весовых карт для повышения точности. Ниже приведены подробности этого -

enter image description here

Теперь, спрашивая здесь и в нескольких других местах, я узнаю о 2 подходах. Я хочу знать, какой из них правильный, или есть ли другой правильный подход, который является более правильным?

1) Во-первых, этоиспользуйте метод torch.nn.Functional в обучающем цикле -

loss = torch.nn.functional.cross_entropy(output, target, w), где w будет вычисленным пользовательским весом.

2) Второй - использовать reduction='none' при вызове функции потерь за пределамицикл обучения criterion = torch.nn.CrossEntropy(reduction='none')

, а затем в цикле обучения, умноженном на произвольный вес -

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

Теперь я немного растерян, какой из них правильный или есть какой-то другой способ,или оба правы?

...