Проверьте ответ здесь .
Я также разместил его ниже.
================================================= ==========
Данные читаются с использованием ImageFolder
. Задача - двоичная классификация изображений с 498 изображениями в наборе данных, которые равномерно распределены между обоими классами (249 изображений каждый).
img_dataset = ImageFolder(..., transforms=t)
1. SubsetRandomSampler
dataset_size = len(img_dataset)
dataset_indices = list(range(dataset_size))
np.random.shuffle(dataset_indices)
val_split_index = int(np.floor(0.2 * dataset_size))
train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=8, sampler=train_sampler)
validation_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=1, sampler=val_sampler)
2. random_split
Здесь из общего числа 498 изображений 400 случайным образом назначаются для обучения, а остальные 98 - для проверки.
dataset_train, dataset_valid = random_split(img_dataset, (400, 98))
train_loader = DataLoader(dataset=dataset_train, shuffle=True, batch_size=8)
val_loader = DataLoader(dataset=dataset_valid, shuffle=False, batch_size=1)
3. WeightedRandomSampler
если кто-то споткнулся здесь в поисках WeightedRandomSampler
, проверьте ответ @ ptrblck здесь для хорошего объяснения того, что происходит ниже.
Теперь, как WeightedRandomSampler
подходит для создания набора train + val? Потому что, в отличие от SubsetRandomSampler
или random_split()
, мы не разделяем на поезд и вал здесь. Мы просто гарантируем, что каждая партия получает одинаковое количество классов во время обучения.
Итак, я думаю, нам нужно использовать WeightedRandomSampler
после random_split()
или SubsetRandomSampler
. Но это не гарантирует, что train и val имеют одинаковое соотношение между классами.
target_list = []
for _, t in imgdataset:
target_list.append(t)
target_list = torch.tensor(target_list)
target_list = target_list[torch.randperm(len(target_list))]
# get_class_distribution() is a function that takes in a dataset and
# returns a dictionary with class count. In this case, the
# get_class_distribution(img_dataset) returns the following -
# {'class_0': 249, 'class_0': 249}
class_count = [i for i in get_class_distribution(img_dataset).values()]
class_weights = 1./torch.tensor(class_count, dtype=torch.float)
class_weights_all = class_weights[target_list]
weighted_sampler = WeightedRandomSampler(
weights=class_weights_all,
num_samples=len(class_weights_all),
replacement=True
)