Я в основном сортирую свои изображения CNN в список с четной и нечетной индексацией. Четный индекс будет иметь положительные изображения, а нечетный индекс будет иметь отрицательные изображения. Вот мой код:
from PIL import Image
import matplotlib.pyplot as plt
import os
import glob
import torch
from torch.utils.data import Dataset
def show_data(data_sample, shape = (28, 28)):
plt.imshow(data_sample[0].numpy().reshape(shape), cmap='gray')
plt.title('y = ' + data_sample[1])
directory="/resources/data"
negative='Negative'
negative_file_path=os.path.join(directory,negative)
negative_files=[os.path.join(negative_file_path,file) for file in os.listdir(negative_file_path) if file.endswith(".jpg")]
negative_files.sort()
negative_files[0:3]
positive="Positive"
positive_file_path=os.path.join(directory,positive)
positive_files=[os.path.join(positive_file_path,file) for file in os.listdir(positive_file_path) if file.endswith(".jpg")]
positive_files.sort()
positive_files[0:3]
n = len(negative_files)
p = len(positive_files)
number_of_samples = n + p
print(number_of_samples)
Y=torch.zeros([number_of_samples])
Y=Y.type(torch.LongTensor)
Y.type()
Y[::2]=1
Y[1::2]=0