Градиенты равны нулю при ручном обновлении весов в PyTorch - PullRequest
0 голосов
/ 29 ноября 2018

Я пытаюсь реализовать простую нейронную сеть с ручным обновлением весов для MNIST с помощью AUTOGRAD, аналогично примеру AUTOGRAD, приведенному здесь .Это мой код:

import os
import sys

import torch
import torchvision
class Datasets:
    """Helper for extracting datasets."""

    def __init__(self, root='data/', batch_size=25):
        if not os.path.exists(root):
            os.mkdir(root)
        self.root = root
        self.batch_size = batch_size

    def get_mnist_loaders(self):
        train_data = torchvision.datasets.MNIST(
                root=self.root, train=True, download=True)
        test_data = torchvision.datasets.MNIST(
                root=self.root, train=False, download=True)


        train_loader = torch.utils.data.DataLoader(
                dataset=train_data, batch_size=self.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
                dataset=test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, test_loader

    def create_batches(self, data, labels, batch_size):
        return [(data[i:i+batch_size], labels[i:i+batch_size])
            for i in range(0, len(data), max(1, batch_size))]

def train1():
    dtype = torch.float
    n_inputs = 28*28
    n_hidden1 = 300
    n_hidden2 = 100
    n_outputs = 10
    batch_size = 200
    n_epochs = 25
    learning_rate = 0.01
    test_step = 100 
    device = torch.device("cpu")

    datasets = Datasets(batch_size=batch_size)
    train_loader, test_loader = datasets.get_mnist_loaders()

    def feed_forward(X):
        x_shape = list(X.size())
        X = X.view(x_shape[0], x_shape[1]*x_shape[2])
        hidden1 = torch.mm(X, w1)
        hidden1 += b1
        hidden1 = hidden1.clamp(min=0)
        hidden2 = torch.mm(hidden1, w2) + b2
        hidden2 = hidden2.clamp(min=0)
        logits = torch.mm(hidden2, w3) + b3
        softmax = pytorch_softmax(logits)
        return softmax

    def accuracy(y_pred, y):

        if list(y_pred.size()) != list(y.size()):
            raise ValueError('Inputs have different shapes.')

        total_correct = 0
        total = 0
        for i, (y1, y2) in enumerate(zip(y_pred, y)):
            if y1 == y2:
                total_correct += 1
            total += 1

        return total_correct / total

    w1 = torch.randn(n_inputs, n_hidden1, device=device, dtype=dtype, requires_grad=True)
    b1 = torch.nn.Parameter(torch.zeros(n_hidden1), requires_grad=True)

    w2 = torch.randn(n_hidden1, n_hidden2, requires_grad=True)
    b2 = torch.nn.Parameter(torch.zeros(n_hidden2), requires_grad=True)

    w3 = torch.randn(n_hidden2, n_outputs, dtype=dtype, requires_grad=True)
    b3 = torch.nn.Parameter(torch.zeros(n_outputs), requires_grad=True)

    pytorch_softmax = torch.nn.Softmax(0)
    pytorch_cross_entropy = torch.nn.CrossEntropyLoss(reduction='elementwise_mean')

    step = 0
    for epoch in range(n_epochs):
        batches = datasets.create_batches(train_loader.dataset.train_data,
                                          train_loader.dataset.train_labels,
                                          batch_size)
        for x, y in batches:
            step += 1

            softmax = feed_forward(x.float())
            vals, y_pred = torch.max(softmax, 1)
            accuracy_ = accuracy(y_pred, y)
            cross_entropy = pytorch_cross_entropy(softmax, y)

            print(epoch, step, cross_entropy.item(), accuracy_)

            cross_entropy.backward()

            with torch.no_grad():
                w1 -= learning_rate * w1.grad
                w2 -= learning_rate * w2.grad
                w3 -= learning_rate * w3.grad

                b1 -= learning_rate * b1.grad
                b2 -= learning_rate * b2.grad
                b3 -= learning_rate * b3.grad

                w1.grad.zero_()
                w2.grad.zero_()
                w3.grad.zero_()

                b1.grad.zero_()
                b2.grad.zero_()
                b3.grad.zero_()

if __name__ == '__main__':
    train1()

Однако сеть не работает.Когда я печатаю части градиентов (например, w1.grad.data[:10, :10]), они состоят из нулей.Я пытался использовать weight.data и weight.grad.data для обновления весов и пытался удалить часть w.grad.zero_() (хотя это в примере), но это не помогает.В чем здесь проблема?

1 Ответ

0 голосов
/ 01 декабря 2018

Когда вы заполняете сеть с помощью Tensor, градиенты не рассчитываются по умолчанию.Чтобы это работало, вы можете заключить ваш FloatTensor в torch.autograd.Variable или установить свойство requires_grad тензорного. Вот пример.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...