Проблемы измерения пользовательских DataLoader PyTorch для CNN - PullRequest
0 голосов
/ 27 марта 2020

Я написал собственный набор данных и загрузчик данных для проекта PyTorch CNN. Вот соответствующий код для набора данных

class MyDataset(Dataset):

  def __init__(self): 
    pass

  def __len__(self):
    return COUNT

  def __getitem__(self, idx):
    x, y = X[idx], Y[idx]
    x = image_augment(x)  # custom func to resize image to 32x32
    return x, y

Форма каждой тренировки x равна [4, 32, 32, 3].

А вот мой Net код, взятый непосредственно из этого примера PyTorch .

class Net(nn.Module):

    def __init__(self, nc):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, nc)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Когда я пытаюсь обучить этому net на моих данных из мой DataLoader, я получаю сообщение об ошибке Given groups=1, weight of size [6, 3, 5, 5], expected input[4, 32, 32, 3] to have 3 channels, but got 200 channels instead. Мне кажется, моя проблема связана с формой моих данных, поступающих из моего DataLoader с использованием x.view(4, 3, 32, 32), но затем я получил сообщение об ошибке I couldn't use Conv2D on a ByteTensor. Я немного растерялся и был бы очень признателен за любую помощь. Спасибо!

1 Ответ

0 голосов
/ 27 марта 2020

Я понял это в конце концов. Пришлось x = x.view(x.shape[0], 3, self.img_height, self.img_width).type('torch.FloatTensor'). например. Это сделало бы этот обмен с [4, 32, 32, 3] на [4, 3, 32, 32].

...