Ваши данные не масштабируются и не нормализуются. Если вы посмотрите на переменную images
в цикле тренировки, то это значение между 0 и 255. Это, по всей вероятности, повредит вашему тренировочному процессу.
Существуют более чистые способы подбора выборки набора данных, как вы хотите, но без внесения изменений. большая часть вашего кода, используя это определение загрузки данных
import torchvision.transforms as transforms
#Load Dataset
preprocessing = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = dsets.MNIST(root='./data', train=True, transform=preprocessing, download=True)
#Filter samples by label (to get binary classification) and by number of training samples
Binary_filter=torch.add(train_dataset.targets==1, train_dataset.targets==0)
train_dataset.data, train_dataset.targets = train_dataset.data[Binary_filter],train_dataset.targets[Binary_filter]
TrainSet_filter=torch.cat((torch.ones(num_of_training_samples)
,torch.zeros(len(train_dataset.targets)-num_of_training_samples)),0).bool()
train_dataset.data, train_dataset.targets = train_dataset.data[TrainSet_filter], train_dataset.targets[TrainSet_filter]
#Make Dataset Iterable
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
У меня ~ 100% точности примерно за 5-10 эпох.