Модификация Caffe VGG 16 для обработки 1-канальных изображений на PyTorch - PullRequest
0 голосов
/ 23 января 2019

Я преобразую сеть VGG16 в полностью сверточную сеть, а также модифицирую вход для приема одноканального изображения.Полный код для воспроизведения приведен ниже.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import torchvision.datasets as datasets
import copy

from torch.utils import model_zoo
from torchvision import models
from collections import OrderedDict

def convolutionalize(modules, input_size):
    """
    Recast `modules` into fully convolutional form
    """
    fully_conv_modules = []
    x = Variable(torch.zeros((1, ) + input_size))
    for m in modules:
         if isinstance(m, nn.Linear):
              n = nn.Conv2d(x.size(1), m.weight.size(0), kernel_size=(x.size(2), x.size(3)))
              n.weight.data.view(-1).copy_(m.weight.data.view(-1))
              n.bias.data.view(-1).copy_(m.bias.data.view(-1))
              m = n
         fully_conv_modules.append(m)
         x = m(x)
    return fully_conv_modules



def vgg16(is_caffe=True):
     """
     Load the VGG-16 net for use as a fully convolutional backbone.
     """
     vgg16 = models.vgg16(pretrained=True)
     # cast into fully convolutional form (as list of layers)
     vgg16 = convolutionalize(list(vgg16.features) + list(vgg16.classifier),
                         (3, 224, 224))
     # name layers like the original paper
     names = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3',
    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5',
    'fc6', 'relu6', 'drop6', 'fc7', 'relu7', 'drop7', 'fc8']

    vgg16 = nn.Sequential(OrderedDict(zip(names, vgg16)))

    if is_caffe:
        # substitute original Caffe weights for improved fine-tuning accuracy
        # see https://github.com/jcjohnson/pytorch-vgg
        caffe_params = model_zoo.load_url('https://s3-us-west-2.amazonaws.com/'
                                      'jcjohns-models/vgg16-00b39a1b.pth')
        for new_p, old_p in zip(vgg16.parameters(), caffe_params.values()):
            new_p.data.copy_(old_p.view_as(new_p))
        # surgery: decapitate final classifier
   del vgg16._modules['fc8']  # note: risky use of private interface
   # surgery: keep fuller spatial dims by including incomplete pooling regions
   for m in vgg16.modules():
       if isinstance(m, nn.MaxPool2d):
                m.ceil_mode = True
   return vgg16


class Learner(nn.Module):

     def __init__(self, num_classes, singleChannel=False):
          super().__init__()

          backbone = vgg16(is_caffe=True)
          for k in list(backbone._modules)[-6:]:
                del backbone._modules[k]


          supp_backbone = copy.deepcopy(backbone)

          # Modify conv1_1 of conditioning branch to have 1 input channels
          # Init the weights in the new channels to the channel-wise mean
          # of the pre-trained conv1_1 weights
          if singleChannel==True:
               old_conv1 = backbone._modules['conv1_1'].weight.data
               mean_conv1 = torch.mean(old_conv1, dim=1, keepdim=True)
               new_conv1 = nn.Conv2d(1, old_conv1.size(0), kernel_size=old_conv1.size(2), stride=1, padding=1)
               new_conv1.weight.data = mean_conv1
               new_conv1.bias.data = backbone._modules['conv1_1'].bias.data
               backbone._modules['conv1_1'] = new_conv1

          self.encoder = copy.deepcopy(backbone)
          self.num_classes=num_classes

     def forward(self,im):

          # encode image
          supp_feats = self.encoder(im)

          return supp_feats




 model=Learner(num_classes=2,singleChannel=True).cpu()
 mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
 im2arr = np.array(mnist_trainset[1][0])
 im2arr = im2arr[np.newaxis,:, :,] # shape(1,28,28)

 model.train()
 x=model(torch.from_numpy(im2arr))

Я ожидал, что x будет выводом тензора факела, но получил сообщение об ошибке «ValueError: Ожидаемый 4D-тензор в качестве ввода, вместо этого получил 3D-тензор».на последней строке

1 Ответ

0 голосов
/ 23 января 2019

Ваша форма ввода должна быть Batch-Channel-Height-Width, которая равна 4D. В вашем случае у вас есть только один канал, так что вы «выдавили» это одноэлементное измерение, но pytorch это не нравится!

попробовать

im2arr = im2arr[np.newaxis, np.newaxis, :, :]  # add singleton for the channles as well
...