Проблемы передачи тензора в линейный слой - Pytorch - PullRequest
0 голосов
/ 07 марта 2019

Я пытаюсь построить нейронную сеть, однако я не могу понять, где я ошибаюсь с максимальным слоем пула.

    self.embed1 = nn.Embedding(256, 8)
    self.conv_1 = nn.Conv2d(1, 64, (7,8), padding = (0,0))
    self.fc1 = nn.Linear(64, 2)

def forward(self,x):

    import pdb; pdb.set_trace()
    x = self.embed1(x) #input a tensor of ([1,217]) output size: ([1, 217, 8]) 
    x = x.unsqueeze(0) #conv lay needs a tensor of size (B x C x W x H) so unsqueeze here to make ([1, 1, 217, 8])
    x = self.conv_1(x) #creates 64 filter of size (7, 8).Outputs ([1, 64, 211, 1]) as 6 values lost due to not padding. 

    x = torch.max(x,0) #returning max over the 64 columns. This returns a tuple of length 2 with 64 values in each att, the max val and indices.
    x = x[0] #I only need the max values. This returns a tensor of size ([64, 211, 1])
    x = x.squeeze(2) #linear layer only wants the number of inputs and number of outputs so I squeeze the tensor to ([64, 211])
    x = self.fc1(x) #Error Size mismatch (M1: [64 x 211] M2: [64 x 2])

Я понимаю, почему линейный слой не принимает 211, однако я не понимаю, почему мой тензор после максимизации по столбцам не равен 64 x 2.

Ответы [ 2 ]

0 голосов
/ 07 марта 2019

Если я правильно угадываю ваши намерения, ваша ошибка в том, что вы используете torch.max для 2d maxpooling вместо torch.nn.functional.max_pool2d.Первая уменьшается по тензорному измерению (например, по всем картам объектов или всем горизонтальным линиям), тогда как последняя уменьшается в каждой квадратной пространственной окрестности в плоскости [h, w] [batch, features, h, w] тензор.

0 голосов
/ 07 марта 2019

Вы используете torch.max, возвращает два выхода: максимальное значение по dim = 0 и argmax по этому измерению.Таким образом, вам нужно выбрать только первый выход.(вы можете рассмотреть возможность использования адаптивного максимального пула для этой задачи).

Ваш линейный слой ожидает, что его вход имеет dim 64 (то есть batch_size -by- 64 в форметензор).Тем не менее, кажется, что ваш x[0] имеет форму 13504 x 1 - определенно не 64.

См., Например, этот поток .

...