PyTorch - неправильная маркировка с использованием torchvision.datasets.ImageFolder - PullRequest
0 голосов
/ 12 декабря 2018

Я структурировал свой набор данных следующим образом:

dataset/train/0/456.jpg
dataset/train/1/456456.jpg
dataset/train/2/456.jpg
dataset/train/...

dataset/val/0/878.jpg
dataset/val/1/234.jpg
dataset/val/2/34554.jpg
dataset/val/...

Поэтому я использовал torchvision.datasets.ImageFolder, чтобы импортировать мой набор данных в PyTorch.Тем не менее, кажется, что это не дает правильную метку на правильное изображение.Я добавил свой код ниже:

data_transforms = {
    'train': transforms.Compose(
        [transforms.Resize((176,176)),
         transforms.RandomRotation((0,360)),
         transforms.RandomHorizontalFlip(),
         transforms.RandomVerticalFlip(),
         transforms.CenterCrop(128),         
         transforms.Grayscale(),
         transforms.ToTensor(),
         transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ]),
    'val': transforms.Compose(
        [transforms.Resize((128,128)),
         transforms.Grayscale(),
         transforms.ToTensor(),
         transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ]),
}

data_dir = 'dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Я обнаружил, что метки неверны, используя следующую функцию:

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(dataloaders['val'])
images, labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))
print(labels)

Используя показанные изображения и метки, я проверил вручнуюправильны ли они.К сожалению, ярлыки не соответствуют изображениям.Может кто-нибудь сказать мне, что я делаю не так?

Ответы [ 2 ]

0 голосов
/ 12 декабря 2018

Кто-то помог мне с этим.ImageFolder создает свои собственные внутренние метки.Напечатав image_datasets['train'].class_to_idx, вы можете увидеть, какая этикетка связана с какой внутренней этикеткой.Используя этот словарь, вы можете отследить оригинальную метку.

0 голосов
/ 12 декабря 2018

API ImageFolder предполагает, что ваши данные находятся в «предопределенной» структуре папок.Пожалуйста, ознакомьтесь с приведенным ниже комментарием к коду PyTorch или документации @ https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder

A generic data loader where the images are arranged in this way: ::

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

Это означает, что вам необходимо расположить данные в папках, соответствующих вашим ярлыкам.В приведенном выше случае есть 2 ярлыка, кошки и собаки.

Надеюсь, это поможет!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...