Как реализовать выполнение графа для пакетного обучения в тензорном потоке - PullRequest
0 голосов
/ 10 октября 2019

Я построил свою модель GAN с использованием Keras и обучил ее на своем графическом процессоре, но мой графический процессор использовался только на 15%. Как я могу оптимизировать этот код, используя преимущества Tensorflow 2.0?

  for e in range(1, epochs+1):
  print ('-'*15, 'Epoch %d' % e, '-'*15)
  for _ in tqdm(range(batch_count)):

      rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

      image_batch_hr = x_train_hr[rand_nums]
      image_batch_lr = x_train_lr[rand_nums]
      generated_images_sr = generator.predict(image_batch_lr)

      real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
      fake_data_Y = np.random.random_sample(batch_size)*0.2

      discriminator.trainable = True
      #Update Discriminator
      d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
      d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
      discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

      rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
      image_batch_hr = x_train_hr[rand_nums]
      image_batch_lr = x_train_lr[rand_nums]

      gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
      #Update Generator
      discriminator.trainable = False
      gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr,gan_Y])
...