Я хочу обучить свою сеть в PyTorch партиями переменного размера.В классификации изображений, например, я хотел бы обучить сеть с первой партией изображений размером 128 на 128 ([64,3128,128]) и второй партией изображений размером 256 на 256 ([64,325,556]).В более общем случае в каждой партии я хотел бы произвольно обрезать (изменить размер) все изображения в пакете до размера между некоторым минимальным и максимальным значением.Однако функция преобразования работает только на отдельных примерах, а не на целых пакетах.
Я просмотрел форумы PyTorch и обнаружил, что этого можно добиться с помощью аргумента collate_fn
в DataLoader.Однако, если я прав, collate_fn
возвращает список семплов, а не тензор.Интересно, приводит ли это к более медленному времени обучения, поскольку графический процессор не может обрабатывать весь пакет сразу?
Каков наилучший способ сделать это в PyTorch без необходимости дополнять мой ввод?