Эффективная функция PyTorch DataLoader collate_fn для входов различных измерений - PullRequest
0 голосов
/ 08 января 2019

У меня проблемы с написанием пользовательской функции collate_fn для класса PyTorch DataLoader. Мне нужна пользовательская функция, потому что мои входы имеют разные размеры.

В настоящее время я пытаюсь написать базовую реализацию Stanford MURA paper . Набор данных имеет набор помеченных исследований. Исследование может содержать более одного изображения. Я создал собственный класс Dataset, который складывает эти несколько изображений, используя torch.stack.

Сложенный тензор затем предоставляется в качестве входных данных для модели, и список выходных данных усредняется для получения одного выходного сигнала. Эта реализация прекрасно работает с DataLoader, когда batch_size=1. Однако, когда я пытаюсь установить batch_size равным 8, как в случае с оригинальной статьей, DataLoader завершается неудачно, так как он использует torch.stack для укладки пакета, а входы в моем пакете имеют переменные размеры (так как каждый исследование может иметь несколько изображений).

Чтобы это исправить, я попытался реализовать свою пользовательскую функцию collate_fn.

def collate_fn(batch):
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets

Затем в цикле тренировочной эпохи я повторяю каждую партию следующим образом:

for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()  
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)

Тем не менее, это не дает мне каких-либо улучшенных таймингов в эпоху и все равно занимает столько же времени, сколько и при размере партии всего 1. Я также пытался установить размер партии равным 32, и это не улучшает время либо.

Я что-то не так делаю? Есть ли лучший подход к этому?

...