Условный GAN - тасует / разделяет два набора данных одинаково - PullRequest
1 голос
/ 22 апреля 2019

Я пытаюсь использовать поезд DCGAN для раскрашивания некоторых изображений.При этом я настраиваю свой GAN на версии изображений в градациях серого.Затем я хочу обучить свой GAN / дискриминатор сначала партию реальных изображений, а затем партию поддельных изображений.Время от времени я хочу сравнивать раскрашенную, полутоновую и основанную на правде версию изображений.Поэтому мне нужно, чтобы партии реальных / серых изображений были разделены таким же образом.Я использую Pytorch.Глядя на код, который я включил, они должны давать одинаковые пакеты.Однако они этого не делают.

Я пробовал без работника_init_fn.Я также пробовал разные случайные вызовы функций и передавал их в worker_init_fn безрезультатно.

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

dataloader_gray = torch.utils.data.DataLoader(dataset_gray, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

for i, (data, data_gray) in enumerate(zip(dataloader, dataloader_gray)):
    doStuff()

1 Ответ

0 голосов
/ 22 апреля 2019

Как указал Харан Раджкумар , гораздо лучшее решение будет включать в себя конкатенацию обоих наборов данных заранее и применение torch.utils.DataLoader после этого (при условии, что оба объекта torch.utils.Dataset содержат изображения в том же порядке в начале ).

Обратите внимание, что для выполнения этой операции не нужно создавать отдельный класс, torch.utils.data.ConcatDataset предоставляет эту функциональность "из коробки".

Не уверен насчет вашего точного кода, но этого должно быть достаточно (или, по крайней мере, достаточно, чтобы направить вас в нужном направлении):

import torch

dataloader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset(dataset, dataset_gray),
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)

for i, (data, data_gray) in enumerate(dataloader):
    doStuff()

Как видите, он намного удобнее для чтения и использования.

...