Пыторч Н.Н. и общение между классами - PullRequest
0 голосов
/ 06 января 2019

Я новичок в Python и Pytorch, и у меня проблемы с пониманием, как это работает.

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            ..
        def forward(self, x):
            ..
            return x
    net = Net()


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

enter image description here

Итак, это код, и я рисую то, что я понимаю, из кода на картинке. У меня есть несколько вопросов:

A) Почему я не могу использовать nn.CrossEntropy вместо «критерия» непосредственно в коде? Какая разница, если я назначу его переменной? Я получаю эту ошибку: bool значение Tensor с более чем одним значением неоднозначно

B) Почему, когда класс Net получает объект (nn) (я предполагал, что когда используется «как», объект создается), тогда класс Net может просто использовать в обратном направлении впоследствии? Он должен быть частью nn, а не Net. Не могли бы вы понять это для меня?

C) Хотя optim - это другой объект, как параметры, оптимизированные с помощью optim, могут влиять на nn? Я не понимаю, как они передают переменные и обновляют друг друга?

1 Ответ

0 голосов
/ 07 января 2019

A) Установив его как переменную в одном месте, это помогает упростить изменение функции потерь в одном месте, а не вводить nn.MSELoss во многих местах, поскольку код увеличивается в размере и сложности. Менее вероятно, чтобы делать ошибки в принципе.
Что касается ошибки, то для ответа на эту ошибку bool потребуется больше информации. На какой строке, какие входы и т. Д. Слишком мало информации, чтобы помочь там.

B) Net (nn.Module) наследуется от nn.Module, который добавит в обратную сторону все операции, которые вы добавляете в класс. См. документы для получения дополнительной информации.

C) «сеть» - это объект. net.parameters () - это итератор, который перебирает все параметры в сетевом объекте. Поэтому он передается по ссылке, а не по параметрам.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...