Начинающий PyTorch - RuntimeError: фигура '[16, 400]' недопустима для ввода размером 9600 - PullRequest
0 голосов
/ 27 февраля 2020

Я пытаюсь построить CNN, но получаю эту ошибку:

---> 52         x = x.view(x.size(0), 5 * 5 * 16)
RuntimeError: shape '[16, 400]' is invalid for input of size 9600

Мне не ясно, какими должны быть входные данные строки 'x.view'. Кроме того, я не совсем понимаю, сколько раз я должен иметь эту функцию «x.view» в моем коде. Это только один раз, после 3 сверточных слоев и 2 линейных слоев? Или это 5 раз, один за каждым слоем?

Вот мой код:

CNN

import torch.nn.functional as F

# Convolutional neural network
class ConvNet(nn.Module):

    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=3, 
            out_channels=16, 
            kernel_size=3)

        self.conv2 = nn.Conv2d(
            in_channels=16, 
            out_channels=24, 
            kernel_size=4)

        self.conv3 = nn.Conv2d(
            in_channels=24, 
            out_channels=32, 
            kernel_size=4)

        self.dropout = nn.Dropout2d(p=0.3)

        self.pool = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(512, 10)

        self.final = nn.Softmax(dim=1)

    def forward(self, x):

        print('shape 0 ' + str(x.shape))

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  
        x = self.dropout(x)

        print('shape 1 ' + str(x.shape))

        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  
        x = self.dropout(x)

        print('shape 2 ' + str(x.shape))

        # x = F.max_pool2d(F.relu(self.conv3(x)), 2)  
        # x = self.dropout(x)

        x = F.interpolate(x, size=(5, 5))  
        x = x.view(x.size(0), 5 * 5 * 16)

        x = self.fc1(x) 

        return x

net = ConvNet()

Может кто-нибудь помочь мне понять проблему?

Вывод 'x.shape' таков:

Факел формы 0.Размер ([16, 3, 256, 256])

Факел формы 1.Размер ([16, 16, 127, 127])

факел формы 2.Размер ([16, 24, 62, 62])

Спасибо

1 Ответ

1 голос
/ 27 февраля 2020

Это означает, что вместо этого произведение канала и пространственных измерений не равно 5*5*16. Чтобы сгладить тензор, замените x = x.view(x.size(0), 5 * 5 * 16) на:

x = x.view(x.size(0), -1)

и self.fc1 = nn.Linear(600, 120) на:

self.fc1 = nn.Linear(600, 120)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...