Ошибка размерности в плотной сети для cGAN - PullRequest
0 голосов
/ 15 февраля 2020

Я тренирую условный ГАН. Для тех, кто не знает, y - метка, а z - некоторый шум, используемый в качестве начального числа. Мои изображения из набора данных были изменены до 32x32 и RGB. (Всего 2750 изображений в 275 классах)

Моя модель выглядит примерно так: Первая - это мой генератор

    self.dense_z = tf.keras.layers.Dense(256, activation='relu')
    self.dropout_z = tf.keras.layers.Dropout(0.5)
    self.dense_y = tf.keras.layers.Dense(256, activation='relu')
    self.dropout_y = tf.keras.layers.Dropout(0.5)
    self.combined_dense = tf.keras.layers.Dense(512, activation='relu')
    self.dropout_x = tf.keras.layers.Dropout(0.5)
    self.final_dense = tf.keras.layers.Dense(32 * 32 * self.channels, activation='tanh')
    self.reshape = tf.keras.layers.Reshape((32, 32, self.channels))

Мой дискриминатор выглядит так:

        self.flatten = tf.keras.layers.Flatten()
        self.maxout_z = MaxoutDense(240, k=5, activation='relu', drop_prob=0.5)
        self.maxout_y = MaxoutDense(50, k=5, activation='relu', drop_prob=0.5)
        self.maxout_x = MaxoutDense(240, k=4, activation='relu', drop_prob=0.5)
        self.out = tf.keras.layers.Dense(1)

Текущая ошибка:

Input 0 of layer dense_129 is incompatible with the layer: : expected min_ndim=2, found ndim=1. Full shape received: [10]Input 0 of layer dense_129 is incompatible with the layer: : expected min_ndim=2, found ndim=1. Full shape received: [10]

Я предполагаю, что полученная полная форма может ссылаться на количество классов, но я не совсем уверен. Каждый раз, когда я запускаю этот код, значение layer_dense увеличивается. Таким образом, у меня нет способа изолировать источник моей ошибки

Шаги обучения выглядят так:

while i<2200:
          images= train_dataset[i]
          labels=y_train[i]
          gen_loss, disc_loss = train_step(images, labels)
          total_gen_loss += gen_loss
          total_disc_loss += disc_loss
          i=i+1
        print('Time for epoch {} is {} sec - gen_loss = {}, disc_loss = {}'.format(epoch, time.time() - start, total_gen_loss / batch_size, total_disc_loss / batch_size))
        if epoch % save_interval == 0:
            save_imgs(epoch, generator, seed)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...