Как создать пользовательский набор данных для задачи классификации, когда классом называется имя папки с помощью Pytorch? - PullRequest
0 голосов
/ 19 июня 2019

Проблема в том, что dataloader возвращает неправильный класс для соответствующего изображения? например, если я печатаю class_to_idx из train_loader, когда размер пакета равен 1, я ожидал получить один класс на пакет, но в настоящее время он возвращает все классы, что составляет 15 классов на изображение.

В этом случае классы представляют собой класс папок (все изображения в одной папке принадлежат одному классу)

фрагмент здесь: (это функция для возврата класса из имени папки dir)

import os
def find_classes(dir):   # Finds the class folders in a dataset, dir (string): Root directory path.

        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

вот основной фрагмент для создания пользовательского набора данных и кодировщика данных def main ():

class CustomDataset(Dataset):    

    def __init__(self, image_paths, classes, class_to_id):  

        self.image_paths = image_paths
        self.transforms = transforms.ToTensor() 
        classes, class_to_id = find_classes('D:/Neda/Echo_View_Classification/avi_images/')
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):

        image = Image.open(self.image_paths[index])
        t_image = image.convert('L')
        t_image = self.transforms(t_image)    

        class_to_idx = self.class_to_idx

        return t_image, class_to_idx, self.image_paths[index]

    def __len__(self): 

        return len(self.image_paths)


folder_data = glob.glob("D:\\Neda\\Echo_View_Classification\\avi_images\\*\\*.png") # no augmnetation
#numpy.savetxt('distribution_class.csv', numpy.c_[folder_data], fmt=['%s'], comments='', delimiter = ",")                    

 #split these path using a certain percentage
len_data = len(folder_data)
print("count of dataset: ", len_data)

split_1 = int(0.6 * len(folder_data))
split_2 = int(0.8 * len(folder_data))

folder_data.sort()

train_image_paths = folder_data[:split_1]
print("count of train images is: ", len(train_image_paths)) 
numpy.savetxt('im_training_path_1.csv', numpy.c_[train_image_paths], fmt=['%s'], comments='', delimiter = ",")                    


valid_image_paths = folder_data[split_1:split_2]
print("count of validation image is: ", len(valid_image_paths))
numpy.savetxt('im_valid_path_1.csv', numpy.c_[valid_image_paths], fmt=['%s'], comments='', delimiter = ",")     


test_image_paths = folder_data[split_2:]
print("count of test images is: ", len(test_image_paths)) 
numpy.savetxt('im_testing_path_1.csv', numpy.c_[test_image_paths], fmt=['%s'], comments='', delimiter = ",")                    

classes = ['1_PLAX_1_PLAX_full',
  '1_PLAX_2_PLAX_valves',
  '1_PLAX_4_PLAX_TV',
  '2_PSAX_1_PSAX_AV',
  '2_PSAX_2_PSAX_LV',
  '3_Apical_1_MV_LA_IAS',
  '3_Apical_2_A2CH',
  '3_Apical_3_A3CH',
  '3_Apical_5_A5CH',
  '4_A4CH_1_A4CH_LV',
  '4_A4CH_2_A4CH_RV',
  '4_Subcostal_1_Subcostal_heart',
  '4_Subcostal_2_Subcostal_IVC',
  'root_5_Suprasternal',
  'root_6_OTHER']


class_to_idx = {'1_PLAX_1_PLAX_full': 0,
  '1_PLAX_2_PLAX_valves': 1,
  '1_PLAX_4_PLAX_TV': 2,
  '2_PSAX_1_PSAX_AV': 3,
  '2_PSAX_2_PSAX_LV': 4,
  '3_Apical_1_MV_LA_IAS': 5,
  '3_Apical_2_A2CH': 6,
  '3_Apical_3_A3CH': 7,
  '3_Apical_5_A5CH': 8,
  '4_A4CH_1_A4CH_LV': 9,
  '4_A4CH_2_A4CH_RV': 10,
  '4_Subcostal_1_Subcostal_heart': 11,
  '4_Subcostal_2_Subcostal_IVC': 12,
  'root_5_Suprasternal': 13,
  'root_6_OTHER': 14}


train_dataset = CustomDataset(train_image_paths, class_to_idx, classes)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)

valid_dataset = CustomDataset(valid_image_paths,  class_to_idx, classes)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)


test_dataset = CustomDataset(test_image_paths, class_to_idx, classes)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)  

dataLoaders = {
        'train': train_loader,
        'valid': valid_loader,
         'test': test_loader,
        }

1 Ответ

0 голосов
/ 19 июня 2019

Я думаю ImageFolder из torchvision.datasets поможет вам в загрузке ваших данных.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...