Использование Upsample вместо ConvTranspose2d приводит к нехватке памяти на этапе вычисления градиента - PullRequest
0 голосов
/ 26 сентября 2018

Видеокарта: GTX 1070 Ti (8 ГБ), размер пакета 64, размер входного изображения 128×128. У меня была такая U-Net с ResNet-152 в качестве кодировщика, которая работала довольно хорошо:

class UNetResNet(nn.Module):
    """U-Net style segmentation network with a torchvision ResNet encoder.

    The encoder stages (stem, layer1..layer4) provide skip connections that are
    concatenated with the decoder outputs along the channel dimension.

    Args:
        encoder_depth: which ResNet to use as the encoder; one of 34, 101, 152.
        num_classes: number of output channels of the final 1x1 convolution.
        num_filters: base width of the decoder; decoder channel counts are
            multiples of it.
        dropout_2d: channel-dropout probability applied before the final
            convolution (only while the module is in training mode).
        pretrained: forwarded to the torchvision ResNet constructor.
        is_deconv: forwarded to every DecoderBlockV. True selects the
            ConvTranspose2d upsampling path; False selects the nn.Upsample path,
            which upsamples *before* convolving and therefore needs noticeably
            more activation memory at the same batch size.

    Raises:
        NotImplementedError: if ``encoder_depth`` is not 34, 101 or 152.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        # depth -> (torchvision constructor, channel count of encoder.layer4).
        # Only the selected constructor is actually called.
        encoders = {
            34: (torchvision.models.resnet34, 512),
            101: (torchvision.models.resnet101, 2048),
            152: (torchvision.models.resnet152, 2048),
        }
        if encoder_depth not in encoders:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
        build_encoder, bottom_channel_nr = encoders[encoder_depth]
        self.encoder = build_encoder(pretrained=pretrained)

        self.pool = nn.MaxPool2d(2, 2)

        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.relu = nn.ReLU(inplace=True)

        # Stem: ResNet conv1/bn1/relu followed by an extra 2x2 max-pool.
        # TODO(author): the author would like to remove this pool layer.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4
        # is_deconv=False: a deconv center would double the spatial size and no
        # longer match conv5 in the concatenation performed in forward().
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        # Each decoder input = previous decoder output + matching encoder skip.
        self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                  is_deconv)
        # dec1 has no skip connection.
        self.dec1 = DecoderBlockV(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections; return logits."""
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        # The functional dropout does not consult the module's train/eval mode,
        # so pass it explicitly: dropout on in train(), off in eval().
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))
# blocks
class DecoderBlockV(nn.Module):
    """Decoder stage that doubles the spatial size of its input.

    Args:
        in_channels: channel count of the incoming feature map.
        middle_channels: channels produced by the first ConvRelu.
        out_channels: channels produced by the block.
        is_deconv: if True, ConvRelu -> ConvTranspose2d(k=4, s=2, p=1) ->
            BatchNorm2d -> ReLU. If False, bilinear nn.Upsample(x2) plus two
            ConvRelu layers.
        conv_before_upsample: keyword-only, used only when ``is_deconv`` is
            False. If True, the first ConvRelu runs at the input resolution and
            the upsample happens afterwards (ConvRelu -> Upsample -> ConvRelu).
            The default False keeps the original layer order (and therefore the
            original state_dict keys).

    Memory note: with ``is_deconv=False`` and the default order, the input is
    upsampled before any convolution. The first ConvRelu therefore receives an
    ``in_channels`` tensor with 4x as many spatial positions as the input, and
    both ConvRelu outputs are produced at that larger resolution. The deconv
    path runs its first ConvRelu at the input resolution. This is the likely
    reason the Upsample path runs out of GPU memory during backward at a batch
    size where the deconv path fits (assumes ConvRelu is conv + ReLU, which is
    defined elsewhere — TODO confirm).
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True,
                 *, conv_before_upsample=False):
        # Zero-argument super(): the class must not be named explicitly here
        # (naming a different class breaks construction).
        super().__init__()
        self.in_channels = in_channels

        if is_deconv:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                # (H - 1) * 2 - 2 * 1 + 4 = 2H: exact 2x upsampling.
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        elif conv_before_upsample:
            # Reduce channels first at low resolution, then upsample: the large
            # in_channels tensor is never materialised at 2x resolution.
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(middle_channels, out_channels),
            )
        else:
            # NOTE(review): align_corners is deliberately not passed; older
            # PyTorch releases (this code targets the Variable-era API) lack it.
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)



class DecoderCenter(nn.Module):
    """Bottleneck block between the encoder output and the first decoder stage.

    Args:
        in_channels: channel count of the incoming feature map.
        middle_channels: channels produced by the first ConvRelu.
        out_channels: channels produced by the block.
        is_deconv: if True, finish with ConvTranspose2d(k=4, s=2, p=1) +
            BatchNorm2d + ReLU, which doubles the spatial size. If False, finish
            with a second ConvRelu instead (the resulting spatial size then
            depends on ConvRelu, which is defined elsewhere).
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels

        layers = [ConvRelu(in_channels, middle_channels)]
        if not is_deconv:
            layers.append(ConvRelu(middle_channels, out_channels))
        else:
            # Deconvolution parameters chosen to avoid checkerboard artifacts,
            # following https://distill.pub/2016/deconv-checkerboard/
            layers.extend([
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

Затем я захотел использовать Upsample вместо ConvTranspose2d, просто установив is_deconv = False в DecoderBlockV, — и на этапе вычисления градиента (backward) возникла ошибка CUDA out of memory. Почему? Помогло только уменьшение размера пакета с 64 до 40.

~/anaconda3/lib/python3.6/site-packages/steppy/base.py in fit_transform(self, *args, **kwargs)
    603             dict: outputs
    604         """
--> 605         self.fit(*args, **kwargs)
    606         return self.transform(*args, **kwargs)
    607 

~/Desktop/ml/salt/open-solution-salt-identification-master/common_blocks/models.py in fit(self, datagen, validation_datagen, meta_valid)
     76             for batch_id, data in enumerate(batch_gen):
     77                 self.callbacks.on_batch_begin()
---> 78                 metrics = self._fit_loop(data)
     79                 self.callbacks.on_batch_end(metrics=metrics)
     80                 if batch_id == steps:

~/Desktop/ml/salt/open-solution-salt-identification-master/common_blocks/models.py in _fit_loop(self, data)
    113             batch_loss = sum(partial_batch_losses.values())
    114         partial_batch_losses['sum'] = batch_loss
--> 115         batch_loss.backward()
    116         self.optimizer.step()
    117 

~/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py in backward(self, gradient, retain_graph, create_graph, retain_variables)
    165                 Variable.
    166         """
--> 167         torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
    168 
    169     def register_hook(self, hook):

~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(variables, grad_variables, retain_graph, create_graph, retain_variables)
     97 
     98     Variable._execution_engine.run_backward(
---> 99         variables, grad_variables, retain_graph)
    100 
    101 

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1518244421288/work/torch/lib/THC/generic/THCStorage.cu:58
...