Ошибка при проверке ввода: ожидалось, что input_2 будет иметь форму (250, 250, 3), но получил массив с формой (200, 200, 3) при обучении GAN - PullRequest
0 голосов
/ 07 августа 2020

Переменная x_train имеет форму (1000, 100, 100, 3), а y_train имеет форму (1000, 250, 250, 3). Ошибка возникает в строке d_loss_real = D.train_on_batch (imgy, valid) . Я пробовал изменить размер ввода D, но когда я меняю его на (200, 200, 3), это дает мне ту же ошибку, но теперь ожидаемая форма и форма ввода поменяны.

epochs = 1
batch_size = 10
sample_interval = 1

with open('X.data', 'rb') as f:
    X_train = np.array(pickle.load(f))/255
with open('Y.data', 'rb') as f:
    Y_train = np.array(pickle.load(f))/255
inputs = tf.keras.Input(shape=(100, 100, 3))
x = Conv2D(8, (10, 10), activation='relu', padding='same')(inputs)
x = UpSampling2D((2.5, 2.5))(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = Conv2D(3, (10, 10), activation='relu', padding='same')(x)
G = tf.keras.Model(inputs=inputs, outputs=x)

G.summary()

inputs = tf.keras.Input(shape=(250, 250, 3))
x = Conv2D(8, (10, 10), activation='relu', padding='same')(inputs)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (10, 10), activation='relu', padding='same')(x)
x = MaxPooling2D((4, 4), padding='same')(x)
x = Conv2D(1, (10, 10), activation='relu', padding='same')(x)
x = AveragePooling2D((32, 32), padding='same')(x)

D = tf.keras.Model(inputs=inputs, outputs=x)
D.compile(loss='binary_crossentropy', optimizer='adam')


z =tf.keras. Input(shape=(100, 100, 3))
img = G(z)
D.trainable = False
validity = D(img)
combined = tf.keras.Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='adam')

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgx = X_train[idx]
    imgy = Y_train[idx]

    g_loss = combined.train_on_batch(imgx, valid)
    
    gen_imgs = G.predict(imgx)
    
    d_loss_real = D.train_on_batch(imgy, valid)
    d_loss_fake = D.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...