Попытка обучить cGAN с данными mnist, но возникла эта ошибка AttributeError: объект 'ZipDataset' не имеет атрибута 'shape' - PullRequest
0 голосов
/ 30 мая 2020

Это код, который тренирует cGAN один раз (за один шаг)

generator, discriminator = build_generator(), build_discriminator()
generator_optimizer = optimizers.Adam(1e-4)
discriminator_optimizer = optimizers.Adam(1e-4)

noise = tf.data.Dataset.from_tensor_slices(tf.random.normal([BATCH_SIZE, noise_dim]))
#shape (256, 100)
images = tf.data.Dataset.from_tensor_slices(images)
#shape (256, 28, 28)
label = tf.data.Dataset.from_tensor_slices(label)
#shape (256, 10)

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    #error occured at this line
    generated_images = generator(tf.data.Dataset.zip(((noise, label), )), training=True)
    real_output = discriminator(tf.data.Dataset.zip(((images, label), )), training=True)
    fake_output = discriminator(tf.data.Dataset.zip(((generated_images, label), )), training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

#evaluate grad 
generator_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
discriminator_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
#train one step
generator_optimizer.apply_gradients(zip(generator_grad, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(discriminator_grad, discriminator.trainable_variables))

Входные данные сети генератора - шум и метка, каждая форма - (100,) и (10,). Входом сети дискриминатора являются изображение и метка, и каждая форма - это (28, 28) и (10,). Я использовал generator(tf.data.Dataset.zip(((noise, label), )), training=True), потому что вход генератора несколько. Но я получил эту ошибку в этом месте AttributeError: объект 'ZipDataset' не имеет атрибута 'shape'

Я использую тензорный поток 2.1.0. Пожалуйста помоги. Спасибо!

...