I was trying to implement a DCGAN in PyTorch, but during training a single iteration of the training loop takes more than 7-8 minutes on the GPU in Google Colab. I can't figure out what is wrong with the code. I have tried many approaches to get around this problem, but nothing seems to work.
Here is my training loop, which was taking more than 7-8 minutes per iteration:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

device = torch.device("cuda:0")
dis = Discriminator().to(device)
Gen = Generator().to(device)
GAN_loss = nn.BCELoss().to(device)
D_optimizer = optim.Adam(dis.parameters(), lr = 0.0002, betas = (0.5, 0.999))
G_optimizer = optim.Adam(Gen.parameters(), lr = 0.0002, betas = (0.5, 0.999))

# Labels for the BCE loss (real = 1, fake = 0)
real_label = 1.
fake_label = 0.

path = 'gdrive/My Drive/New_data/'
path2 = 'gdrive/My Drive/New_cropped/'
train_data_list = os.listdir(path)
train_data_len = len(train_data_list)
minibatch_size = 64
epochs = 10
G_losses = []
D_losses = []

# Split the file list into minibatches of size `minibatch_size`
final_itr = (train_data_len + minibatch_size - 1) // minibatch_size
data_list = [train_data_list[i * minibatch_size : (i + 1) * minibatch_size] for i in range(final_itr)]
for epoch in range(epochs):
    for count, data in enumerate(data_list):
        # Load every image of the minibatch from Google Drive with cv2
        train_img = []
        sample_img = []
        for image in data:
            img_train = cv2.imread(path + image).T / 255
            img_train = img_train.reshape(1, img_train.shape[0], img_train.shape[1], img_train.shape[2])
            img_sample = cv2.imread(path2 + image, 0).T / 255
            img_sample = img_sample.reshape(1, 1, img_sample.shape[0], img_sample.shape[1])
            train_img.append(img_train)
            sample_img.append(img_sample)
            assert img_sample.shape == (1, 1, 144, 144)

        # Stack the minibatch and move it to the GPU (cast to float32 to match the models)
        train_image = Variable(torch.from_numpy(np.concatenate(train_img, axis = 0)).float().cuda())
        sample_image = Variable(torch.from_numpy(np.concatenate(sample_img, axis = 0)).float().cuda())
        label = torch.full((train_image.shape[0],), real_label, device=device)

        # Training the discriminator... minimizing -(log(D(x)) + log(1 - D(G(z))))
        dis.zero_grad()
        Gen.zero_grad()
        G_z = Gen(sample_image)
        disc_real_out = dis(train_image).view(-1)
        error_real = GAN_loss(disc_real_out, label)
        error_real.backward()
        disc_fake_out = dis(G_z.detach()).view(-1)
        label.fill_(fake_label)
        error_fake = GAN_loss(disc_fake_out, label)
        error_fake.backward()
        total_disc_error = error_real + error_fake
        D_optimizer.step()

        # Training the generator... maximizing log(D(G(z)))
        # (no detach on G_z here, otherwise no gradient reaches the generator)
        D_G_z = dis(G_z).view(-1)
        label.fill_(real_label)
        error_gen = GAN_loss(D_G_z, label)
        error_gen.backward()
        G_optimizer.step()

        G_losses.append(error_gen.item())
        D_losses.append(total_disc_error.item())
        print("Discriminator Loss : ", total_disc_error.item(), "\t", "Generator Loss : ", error_gen.item())
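To check where the time actually goes inside one iteration, a timing helper roughly like the sketch below can be used to separate the CPU-side image loading from the GPU work. The function name split_iteration_time and the two callables it takes are hypothetical, just for illustration; they are not part of my training code above.

import time
import torch

def split_iteration_time(load_batch, train_on_batch):
    # `load_batch` and `train_on_batch` are hypothetical callables standing in for
    # the cv2 image-loading part and the D/G update part of one loop iteration.
    t0 = time.time()
    batch = load_batch()                 # read and stack the images with cv2
    t1 = time.time()
    train_on_batch(batch)                # discriminator + generator updates
    if torch.cuda.is_available():
        torch.cuda.synchronize()         # wait for queued CUDA kernels before reading the clock
    t2 = time.time()
    print("data loading: %.1f s, training step: %.1f s" % (t1 - t0, t2 - t1))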