Тензор Pytorch, как переключать положение канала - Ошибка времени выполнения - PullRequest
3 голосов
/ 08 января 2020

У меня есть тренировочный набор данных, как показано ниже, где X_train - это 3D с 3 каналами

Форма X_Train: (708, 256, 3) Форма Y_Train: (708, 4)

Затем я конвертирую их в тензор и вводим в загрузчик данных:

X_train=torch.from_numpy(X_data)
y_train=torch.from_numpy(y_data)
training_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=False)

Однако при обучении модели я получаю следующую ошибку: RuntimeError: Данные группы = 1, вес размера 24 3 5, ожидаемый ввод [708, 256, 3] иметь 3 канала, но вместо них получили 256 каналов

Полагаю, это связано с положением канала? В Tensorflow позиция канала находится в конце, но в PyTorch формат выглядит так: «Размер партии x Канал x Высота x Ширина»? Так как же поменять местами позиции в тензоре x_train для соответствия ожидаемому формату в загрузчике данных?

class TwoLayerNet(torch.nn.Module):
    def __init__(self):
        super(TwoLayerNet,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(3, 3*8, kernel_size=5, stride=1),  
            nn.Sigmoid(),
            nn.AvgPool1d(kernel_size=2, stride=0))
        self.conv2 = nn.Sequential(
            nn.Conv1d(3*8, 12, kernel_size=5, stride=1),
            nn.Sigmoid(),
            nn.AvgPool1d(kernel_size=2, stride = 0))
        #self.drop_out = nn.Dropout()

        self.fc1 = nn.Linear(708, 732) 
        self.fc2 = nn.Linear(732, 4)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

1 Ответ

3 голосов
/ 08 января 2020

Использование permute.

X_train = torch.rand(708, 256, 3)
X_train = X_train.permute(2, 0, 1)
X_train.shape
# => torch.Size([3, 708, 256])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...