У меня проблемы с написанием пользовательской функции collate_fn
для класса PyTorch DataLoader
. Мне нужна пользовательская функция, потому что мои входы имеют разные размеры.
В настоящее время я пытаюсь написать базовую реализацию Stanford MURA paper . Набор данных имеет набор помеченных исследований. Исследование может содержать более одного изображения. Я создал собственный класс Dataset
, который складывает эти несколько изображений, используя torch.stack
.
Сложенный тензор затем предоставляется в качестве входных данных для модели, и список выходных данных усредняется для получения одного выходного сигнала. Эта реализация прекрасно работает с DataLoader
, когда batch_size=1
. Однако, когда я пытаюсь установить batch_size
равным 8, как в случае с оригинальной статьей, DataLoader
завершается неудачно, так как он использует torch.stack
для укладки пакета, а входы в моем пакете имеют переменные размеры (так как каждый исследование может иметь несколько изображений).
Чтобы это исправить, я попытался реализовать свою пользовательскую функцию collate_fn
.
def collate_fn(batch):
imgs = [item['images'] for item in batch]
targets = [item['label'] for item in batch]
targets = torch.LongTensor(targets)
return imgs, targets
Затем в цикле тренировочной эпохи я повторяю каждую партию следующим образом:
for image, label in zip(*batch):
label = label.type(torch.FloatTensor)
# wrap them in Variable
image = Variable(image).cuda()
label = Variable(label).cuda()
# forward
output = model(image)
output = torch.mean(output)
loss = criterion(output, label, phase)
Тем не менее, это не дает мне каких-либо улучшенных таймингов в эпоху и все равно занимает столько же времени, сколько и при размере партии всего 1. Я также пытался установить размер партии равным 32, и это не улучшает время либо.
Я что-то не так делаю?
Есть ли лучший подход к этому?