Пользовательский загрузчик данных возвращает список в pytorch - PullRequest
0 голосов
/ 17 июня 2020

Я хочу получить 3 пакета изображений из 3 разных папок. Я написал собственный загрузчик данных в pytorch. но он возвращает список, в котором есть все пакеты, а не один пакет за раз. (работает в google colab)

#custom data loader
class set(Dataset):
    def __init__(self, dataset_input, dataset_expertA, dataset_expertB):
        self.dataset1 = dataset_input
        self.dataset2 = dataset_expertA
        self.dataset3 = dataset_expertB

    def __getitem__(self, index):
        x1 = self.dataset1[index]
        x2 = self.dataset2[index]
        x3 = self.dataset3[index]

        return x1, x2, x3

    def __len__(self):
        return len(self.dataset1)

input_path = "/content/gdrive/My Drive/project/input/"

dataset = datasets.ImageFolder(root= input_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))

expertA_path = "/content/gdrive/My Drive/project/expertA/"

datasetA = datasets.ImageFolder(root= expertA_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))


expertB_path = "/content/gdrive/My Drive/project/expertB/"

datasetB = datasets.ImageFolder(root= expertB_path, transform=transforms.Compose([
                               transforms.Resize([64,64]),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))


data = set(dataset, datasetA, datasetB)
dataloader = torch.utils.data.DataLoader(data, batch_size=64,
                                         shuffle=True, num_workers=2)


for i, (inp, expA, expB) in enumerate(dataloader):

  print(inp.shape)
  break

это печатает ошибку, что inp является списком, и когда я печатаю (inp [0]. shape) я получаю правильную форму, я думаю, что inp содержит все пакеты ie inp [0], inp [1] ...

какую ошибку я делаю в коде загрузчика данных?

1 Ответ

2 голосов
/ 17 июня 2020

datasets.ImageFolder возвращает кортеж (изображение, метка) , следовательно, inp также является кортежем, где inp[0] - изображения, а inp[1] - их соответствующие этикетки. То же самое относится к expA и expB.

Если вам нужны только изображения без меток, вы можете игнорировать метки и просто возвращать изображения при доступе к данным в вашем пользовательском наборе данных:

def __getitem__(self, index):
    image1, label1 = self.dataset1[index]
    image2, label2 = self.dataset2[index]
    image3, label3 = self.dataset3[index]

    return image1, image2, image3
...