Thanks for your answer, @Ioannis Nasios. Yes, my weights are in gdrive. I am training a GAN and trying to figure out how to load those weights and continue training. I saved the discriminator and generator weights, as well as the gan_loss and discriminator_loss. So, do I compile the generator and discriminator networks, load the weights, and then compile the GAN network with their loss? This may be a silly question; it is the first time I train a GAN. Here is the code, followed by my guess at how resuming would look:
import numpy as np
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model
# Project-specific modules; adjust these import paths to your repo layout
from Network import Generator, Discriminator
from Utils_model import VGG_LOSS
import Utils, Utils_model

# Combined network
def get_gan_network(discriminator, shape, generator, optimizer, loss):
    # Freeze the discriminator so only the generator is updated through the GAN
    discriminator.trainable = False
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=[x, gan_output])
    gan.compile(loss=[loss, "binary_crossentropy"],
                loss_weights=[1., 1e-3],
                optimizer=optimizer)
    return gan
def train(x_train_lr, x_train_hr, x_test_lr, x_test_hr, epochs, batch_size, output_dir, model_save_dir, weights_save_dir):
    # image_shape and downscale_factor are globals defined elsewhere in the script
    loss = VGG_LOSS(image_shape)
    batch_count = int(x_train_hr.shape[0] / batch_size)

    #### IF THE IMAGES ARE NOT SQUARE, THIS SHOULD CHANGE
    shape_lr = (image_shape[0]//downscale_factor, image_shape[1]//downscale_factor, image_shape[2])
    shape_hr = x_train_hr[0].shape
    ####

    generator = Generator(shape_lr, shape_hr).generator()
    discriminator = Discriminator(image_shape).discriminator()

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan = get_gan_network(discriminator, shape_lr, generator, optimizer, loss.vgg_loss)

    # Truncate the loss log at the start of a run
    loss_file = open(model_save_dir + '/losses.txt', 'w+')
    loss_file.close()
    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batch_count)):
            # Sample a random batch of paired HR/LR images
            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            generated_images_sr = generator.predict(image_batch_lr)

            # Smoothed labels: real close to 1, fake close to 0
            real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            fake_data_Y = np.random.random_sample(batch_size)*0.2

            # Update the discriminator on real and generated batches
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
            d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Update the generator through the combined model (discriminator frozen)
            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr, gan_Y])

        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        # Append this epoch's losses to the log file created above
        loss_file = open(model_save_dir + '/losses.txt', 'a')
        loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' % (e, gan_loss, discriminator_loss))
        loss_file.close()
        if e == 1 or e % 5 == 0:
            Utils.plot_generated_images(output_dir, e, generator, x_test_hr, x_test_lr)
            generator.save_weights(weights_save_dir + '%d_gen_weights.h5' % e)
            discriminator.save_weights(weights_save_dir + '%d_dis_weights.h5' % e)
        if e % 500 == 0 or e == epochs:  # e runs from 1 to epochs inclusive
            generator.save(model_save_dir + 'gen_model%d.h5' % e)
            discriminator.save(model_save_dir + 'dis_model%d.h5' % e)
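And here is my guess at how resuming would look, just to make the question concrete. This is a minimal, untested sketch: it reuses the same setup as inside train(), and the epoch-100 file names are only placeholders for whichever checkpoint I copy back from gdrive.

# Untested sketch: rebuild the models exactly as in train(), restore the
# saved weights, then rebuild the combined GAN from the restored sub-models.
# Assumes loss, optimizer, shape_lr, shape_hr and weights_save_dir are set up
# the same way as inside train().
generator = Generator(shape_lr, shape_hr).generator()
discriminator = Discriminator(image_shape).discriminator()

generator.load_weights(weights_save_dir + '100_gen_weights.h5')      # placeholder checkpoint
discriminator.load_weights(weights_save_dir + '100_dis_weights.h5')  # placeholder checkpoint

generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
gan = get_gan_network(discriminator, shape_lr, generator, optimizer, loss.vgg_loss)
# ...then run the same epoch/batch loop as above, starting from the next epoch

If I understand correctly, save_weights only stores the layer weights, so the optimizer state would be reset when resuming. Is that acceptable here, or should I save and load the full models instead?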