Я пытаюсь настроить классификатор изображений с помощью Pytorch.Мои образцы изображений имеют 4 канала и имеют размер 28x28 пикселей.Я пытаюсь использовать встроенный torchvision.models.inception_v3 () в качестве моей модели.Всякий раз, когда я пытаюсь запустить свой код, я получаю эту ошибку:
RuntimeError: Расчетный размер дополненного ввода для канала: (1 x 1).Размер ядра: (3 х 3).Размер ядра не может превышать фактический размер ввода в /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/THNN/generic/SpatialConvolutionMM.c:48
Я не могуузнайте, как изменить размер входного сигнала на канал или выясните, что означает ошибка.Я полагаю, что мне нужно изменить размер входных данных для каждого канала, так как я не могу редактировать размер ядра в готовой модели.
Я пробовал заполнение, но это не помогло.Вот сокращенная часть моего кода, которая выдает ошибку, когда я вызываю train ():
import torch
import torchvision as tv
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
model = tv.models.inception_v3()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.9)
trn_dataset = tv.datasets.ImageFolder(
"D:/tests/classification_test_data/trn",
transform=tv.transforms.Compose([tv.transforms.RandomRotation((0,275)), tv.transforms.RandomHorizontalFlip(),
tv.transforms.ToTensor()]))
trn_dataloader = DataLoader(trn_dataset, batch_size=32, num_workers=4, shuffle=True)
for epoch in range(0, 10):
train(trn_dataloader, model, criterion, optimizer, lr_scheduler, 6, 32)
print("End of training")
def train(train_loader, model, criterion, optimizer, scheduler, num_classes, batch_size):
model.train()
scheduler.step()
for index, data in enumerate(train_loader):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
outputs_flatten = flatten_outputs(outputs, num_classes)
loss = criterion(outputs_flatten, labels)
loss.backward()
optimizer.step()
def flatten_outputs(predictions, number_of_classes):
logits_permuted = predictions.permute(0, 2, 3, 1)
logits_permuted_cont = logits_permuted.contiguous()
outputs_flatten = logits_permuted_cont.view(-1, number_of_classes)
return outputs_flatten