Dice и CE потеря не тренировочная сеть вместе - PullRequest
0 голосов
/ 19 ноября 2018

Я тренирую сеть сегментации по вызову Kaggle Salt. Мои кости и ce уменьшаются, но затем внезапно кости увеличиваются, и CE немного подпрыгивает, это продолжает происходить с игральными костями. Я весь день пытался это исправить, но не могу запустить мой код. Я использую только 10 точек данных, чтобы превзойти мои данные, но этого просто не происходит. Любая помощь будет принята с благодарностью.

Участки игры в кости (вверху) и CE:

Кривая потерь

Вот мои кости и поезд:

def dice(input, target,weights=torch.tensor([1,1]).float().cuda()):
    smooth=.001

    dummy=np.zeros([batch_size,2,100,100]) # create dummy to one hot encode target for weighted dice
    dummy[:,0,:,:][target==0]=1 # background class is 0
    dummy[:,1,:,:][target==1]=1 # salt class is 1 


    target=torch.tensor(dummy).float().cuda()

#     print(input.size(),input[:,0,:,:].size())
    input1=input[:,0,:,:].contiguous().view(-1) #flatten both classes seperately
    target1=target[:,0,:,:].contiguous().view(-1)

    input2=input[:,1,:,:].contiguous().view(-1)
    target2=target[:,1,:,:].contiguous().view(-1)

    score1=2*(input1*target1).sum()/(input1.sum()+target1.sum()+smooth) #back
    score2=2*(input2*target2).sum()/(input2.sum()+target2.sum()+smooth) #salt


    score=1-(weights[0]*score1+weights[1]*score2)/2
    if score<0:
        score=score-score

    return(score)
Heres the train:


def train(epoch):
    for idx, batch_data in enumerate(dataloader) : 
        x, target=batch_data['image'].float().cuda(),batch_data['label'].float().cuda()


        optimizer.zero_grad()
        output = net(x)
#         print(output.size())
        output.squeeze_(1)

#         print('out',output.size(),target.size())
        bce_loss = criterion(output, target.long())
        lc.append(bce_loss.item())

        dice_loss = dice((output), target)
        ld.append(dice_loss.item())
        loss =  dice_loss + bce_loss
        l.append(loss.item())

        loss.backward()
        optimizer.step()

        print('Epoch {}, loss {}, bce {}, dice {}'.format(
            epoch, sum(l)/len(l), sum(lc)/len(lc) , sum(ld)/len(ld) ))

Вот остальная часть кода (я сбит с ядра gaggle): https://github.com/bluesky314/Salt-Segmentation/blob/master/kernel-2.ipynb 1 (здесь показано, как я запускаю эту ячейку (14) второй раз, чтобы не возникали взлеты и падения) но видно по сюжету)

dataset=DatasetSalt(limit_paths=10) просто ограничивает набор данных любым числом, используя только верхние пути для получения изображений из

Был бы очень признателен за любую помощь, буквально боролся с этим 8+ часов

...