Pytorch Вопрос из «Обучение глубокому укреплению: практический» - PullRequest
2 голосов
/ 02 октября 2019

Я читаю «Уроки глубокого обучения» Максима Лапана. Я столкнулся с этим кодом в главе 2, и я не понимаю несколько вещей. Кто-нибудь может объяснить, почему вывод print (out) дает три параметра вместо одного введенного нами тензора с плавающей точкой. Кроме того, зачем здесь нужна суперфункция? Наконец, какой параметр x принимает forward? Спасибо.

class OurModule(nn.Module):
    def __init__(self, num_inputs, num_classes, dropout_prob=0.3):  #init 
        super(OurModule, self).__init__() #Call OurModule and pass the net instance (Why is this necessary?) 
        self.pipe = nn.Sequential( #net.pipe is the nn object now
            nn.Linear(num_inputs, 5),
            nn.ReLU(),
            nn.Linear(5, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes),
            nn.Dropout(p=dropout_prob),
            nn.Softmax(dim=1)
        )

    def forward(self, x): #override the default forward method by passing it our net instance and (return the nn object?). x is the tensor? This is called when 'net' receives a param?
        return self.pipe(x)

if __name__ == "__main__":
    net = OurModule(num_inputs=2, num_classes=3)
    print(net)
    v = torch.FloatTensor([[2, 3]])
    out = net(v)
    print(out) #[2,3] put through the forward method of the nn? Why did we get a third param for the output?
    print("Cuda's availability is %s" % torch.cuda.is_available()) #find if gpu is available
    if torch.cuda.is_available():
        print("Data from cuda: %s" % out.to('cuda'))

OurModule.__mro__

Ответы [ 2 ]

3 голосов
/ 02 октября 2019

OurModule определил PyTorch nn.Module, который принимает 2 входы (num_inputs) и производит 3 выходов (num_classes).

Он состоит из:

  1. A Linear слои, которые принимают 2 входы и производят 5 выходы
  2. A ReLU
  3. A Linear слой, который принимает 5 входы и производит20 выходы
  4. A ReLU
  5. A Linear слой, который принимает 20 входы и производит 3 (num_classes) выходов
  6. A Dropout layer
  7. A Softmax layer

Вы создаете v, который состоит из 2 входов и пропускаете его через метод forward() этой сети при вызове net(v). Результат работы этой сети (3 выходы) затем сохраняется в out.

В вашем примере x принимает значение v, torch.FloatTensor([[2, 3]])

2 голосов
/ 02 октября 2019

Хотя @JoshVarty дал отличный ответ, я хотел бы добавить немного.

почему здесь необходима суперфункция

Класс OurModule наследует nn.Module. Суперфункция означает, что вы хотите использовать родительскую функцию (nn.Module), а именно init. Вы можете обратиться к исходному коду , чтобы увидеть, что именно родительский класс делает в функции init.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...