Ошибка pad_sequence: устранение неполадок «Расширенный размер тензорного () должен соответствовать существующему размеру () в не-одиночном измерении 2». - PullRequest
0 голосов
/ 13 января 2020

Я пытаюсь добавить collate_fn в мой загрузчик данных для загрузки изображений переменного размера. Но он продолжает выдавать сообщение об ошибке: «Увеличенный размер тензора (1444) должен соответствовать существующему размеру (1936) при не синглтонном измерении 2. Размеры цели: [3, 1444, 1444]. Размеры тензор: [ 3, 1296, 1936] ", что означает, что что-то не работает правильно с pad_sequence. Любая помощь будет высоко ценится.

#collate function
def my_collate(batch):
    # batch contains a list of tuples of structure (sequence, target)
    targets = [item[1]  for item in batch] # list of labels
    data = [item[0] for item in batch]
    data = pad_sequence(data, batch_first= False)
    return [data, targets]

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) 
#training
trainset = torchvision.datasets.ImageFolder(root='/content/output/train', transform = transform, target_transform=None) #transform here can crop image
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 8,shuffle = False, num_workers = 0,collate_fn=my_collate,pin_memory=True)

RuntimeError                              Traceback (most recent call last)
<ipython-input-77-0efe18aac68b> in <module>()
     10   correct = 0
     11   total = 0
---> 12   for i, data in enumerate(trainloader, 0):
     13       inputs ,labels = data #get inputs
     14       # inputs = torch.FloatTensor(inputs)

3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

<ipython-input-76-d9f902bb6b0f> in my_collate(batch)
     15     print(len(data[0][1][0]))
     16     print(len(data[0][2]))
---> 17     data = pad_sequence(data, batch_first= False)
     18     return [data, targets]
     19 

/usr/local/lib/python3.6/dist-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
    389             out_tensor[i, :length, ...] = tensor
    390         else:
--> 391             out_tensor[:length, i, ...] = tensor
    392 
    393     return out_tensor

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...