Привет, я новичок в pytorch. Я пытался реализовать обучение с помощью CIFAR10 и встроенной модели resnet18. Для этого я собираюсь сначала скачать оригинальный набор данных и применить к нему некоторые преобразования, а также взять 500 выборок из каждого класса из 10 классов и создать новыйнабор данных с 5000 всего обучающих выборок (вместо 50 000 выборок в исходном CIFAR10).
В этом случае я просто извлекаю и упорядочиваю индексы выборок набора в соответствии с их классом, а затем случайным образом выбираю 500их для каждого класса. Но с помощью этого метода я не могу добиться какой-либо производительности от сети, даже если я делаю дополнения данных в исходном наборе. Так есть ли способ действительно получить данные 500 / на класс (используя некоторый пользовательский загрузчик данных), а затем применить к ним увеличение данных для достижения производительности?
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomPerspective(),
transforms.RandomAffine(10),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))])
trainset = torchvision.datasets.CIFAR10(root = './data', train=True, download = True, transform = transform)
for i,target in enumerate(trainset.targets):
sample_dic[target].append(i)
new_dataset = dict((e1, []) for e1 in range(10))
for key in sample_dic.keys():
new_dataset[key] = random.sample(sample_dic[key], 500)
training_indices = [new_dataset[key] for key in new_dataset.keys()]
training_indices = np.asarray(training_indices).flatten()
trainset.data = trainset.data[training_indices]
trainloader = torch.utils.data.DataLoader(trainset, batch_size =4, shuffle=True, num_workers=2)
Я использую Adam Optimizer с lr = 0,001и weight_decay (параметр регуляризации L2) 1e-5. Я разлагаю lr наполовину после каждых 10 эпох. Я ожидаю увидеть хотя бы некоторое улучшение производительности с точки зрения точности теста. Но точность застревает на уровне около 10% и просто колеблется вокруг нее. У вас есть рекомендации, где я делаю неправильно?