Как исправить проблему обучения капсулы для одного класса набора данных MNIST? - PullRequest
0 голосов
/ 16 января 2019

Я тренируюсь в Капсульной сети с кодером и декодером. Он прекрасно работает со всеми классами (10 классов) набора данных MNIST. Но когда я извлекаю одно слово класса (класс 0 или класс 5), а затем тренирую капсульную сеть, восстановление изображения очень плохое.

Где мне нужно изменить настройки сети или у меня ошибка при подготовке данных?

Я пытался:

  1. Я изменил общий класс с 10 (для десяти цифр на 1 для 1 цифры и даже для 2 для 2 цифр).
  2. Когда я использую набор данных MNIST по умолчанию, я не получаю ошибку или размер тензора, но когда я извлекаю определенный класс и затем передаю его в сеть, я сталкиваюсь с такими проблемами, как: а) проблемы с размерами; б) тензор с плавающей точкой. предупреждение.

Я исправил эти вещи, но вручную добавил измерение и преобразовал данные в тензор data.float (). Cuda (). Я сделал это как для случая, то есть когда я использую 10-значные капсулы, так и когда я использую 1-значные капсулы для обучения одного класса.

Но после этого сеть работает нормально, но я получаю очень размытые и плохие реконструкции. Хотя, когда я тренирую весь набор данных MNIST, не извлекая какой-либо класс и не передавая его в сеть, он не выдает никакой ошибки, и реконструкция работает очень хорошо.

Я хотел бы поделиться более подробной информацией и другими частями кода -

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import datasets, transforms

USE_CUDA = True

### **Here we prepare the data for the complete 10 class digit training**###
class Mnist:
    def __init__(self, batch_size):
        dataset_transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])

        train_dataset = datasets.MNIST('../data', train=True, download=True, transform=dataset_transform)
        test_dataset = datasets.MNIST('../data', train=False, download=True, transform=dataset_transform)

        self.train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

## **Here is my code for extracting a single class digit extraction**##
class Mnist:
    def __init__(self,batch_size):

        dataset_transform = transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
        ])


        train_mnist = datasets.MNIST("../data", train=True)  
        test_mnist = datasets.MNIST("../data", train= False)
        train_image, train_label = train_mnist.train_data, train_mnist.train_labels
        test_image, test_label = test_mnist.test_data, test_mnist.test_labels

        train_0, test_0 = [train_image[key] for (key, label) in enumerate(train_label) if int(label) == 5],[test_image[key] for (key, label) in enumerate(test_label) if int(label) == 5]
        train_label_0, test_label_0 = zero__train = [train_label[key] for (key, label) in enumerate(train_label) if int(label) == 5],[test_label[key] for (key, label) in enumerate(test_label) if int(label) == 5]

        train_dataset = tuple(zip(train_0, train_label_0))
        test_dataset = tuple(zip(test_0, test_label_0))

        self.train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Here is the main code for the capsule training.

''' The below code is used for training the 1 class but using the 10 Digit capsules
'''
class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        super(ConvLayer, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )

    def forward(self, x):
        return F.relu(self.conv(x))
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
#         print(f"x at epoch {epoch} is equal to : {x}")

        W = torch.cat([self.W] * batch_size, dim=0)
#         print(f"W at epoch {epoch} is equal to : {W}")
        u_hat = torch.matmul(W, x)
#         print(f"u_hatat epoch {epoch} is equal to : {u_hat}")

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        if USE_CUDA:
            b_ij = b_ij.cuda()
#             print(f"b_ij at epoch {epoch} is equal to : {b_ij}")

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij, dim =1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
#             print(f"b_ij at iteration {iteration} is equal to : {b_ij}")
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes, dim =1)

        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(10))
        if USE_CUDA:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)

        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return reconstructions, masked
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)
#         return self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)
#         print(f"shape of labels, left and right respectively - {labels.size(), left.size(), right.size()}")

        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss

    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss*0.0005 

capsule_net = CapsNet()
if USE_CUDA:
    capsule_net = capsule_net.cuda()
optimizer = Adam(capsule_net.parameters())
capsule_net

##### Here is the problem while training####
batch_size = 100
mnist = Mnist(batch_size)

n_epochs = 5


for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0
    for batch_id, (data, target) in enumerate(mnist.train_loader):

        target = torch.eye(10).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()
            data, target = data.float().cuda(), target.float().cuda() # Here I changed the data to float and it's required only when I am using my extracted dataset for a single class
            data = data[:,:,:] # Use this when 1st MNist data is used
#             data = data[:,None,:,:] # Use this when I am using my extracted single class digits

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

#         if batch_id % 100 == 0:
#             print ("train accuracy:", sum(np.argmax(masked.data.cpu().numpy(), 1) == 
#                                    np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size))

    print (train_loss / len(mnist.train_loader))

I used this to see the main data as image and the reconstructed image

import matplotlib
import matplotlib.pyplot as plt

def plot_images_separately(images):
    "Plot the six MNIST images separately."
    fig = plt.figure()
    for j in range(1, 10):
        ax = fig.add_subplot(1, 10, j)
        ax.matshow(images[j-1], cmap = matplotlib.cm.binary)
        plt.xticks(np.array([]))
        plt.yticks(np.array([]))
    plt.show()
plot_images_separately(data[:10,0].data.cpu().numpy())
plot_images_separately(reconstructions[:10,0].data.cpu().numpy())

result picture

1 Ответ

0 голосов
/ 18 января 2019

Я проверил нормально работающий код, а затем проблемный, и обнаружил, что набор данных, передаваемый в сеть, не имеет такой же природы. Проблемы были -

  1. Данные MNIST, извлеченные для одного класса, не были преобразованы в тензор, и нормализация не применялась, хотя я попытался передать их через преобразование.

Это то, что я сделал, чтобы это исправить -

  1. Я создал возражения на преобразование и тензорное возражение, а затем передал ему элементы понимания списка. Ниже приведены коды и окончательный вывод моей сети -

    Подготовка набора данных класса 0 (набор данных для цифры 5)

    класс Мнист: trans = transforms.ToTensor () normalize = transforms.Normalize ((0.1307,), (0.3081,)) def init (self, batch_size):

        dataset_transform = transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
        ])
    
        trans = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        train_mnist = datasets.MNIST("../data", train=True, transform=dataset_transform)  
        test_mnist = datasets.MNIST("../data", train= False, transform=dataset_transform)
        train_image, train_label = train_mnist.train_data, train_mnist.train_labels
        test_image, test_label = test_mnist.test_data, test_mnist.test_labels
    
    
        train_0, test_0 = [normalize(trans(train_image[key].unsqueeze(2).numpy())) for (key, label) in enumerate(train_label) if int(label) == 5],[test_image[key] for (key, label) in enumerate(test_label) if int(label) == 5]
        train_label_0, test_label_0 = zero__train = [train_label[key] for (key, label) in enumerate(train_label) if int(label) == 5],[test_label[key] for (key, label) in enumerate(test_label) if int(label) == 5]
    
        train_dataset = tuple(zip(train_0, train_label_0))
        test_dataset = tuple(zip(test_0, test_label_0))
    
        self.train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    

введите описание изображения здесь

...