SRGAN: Несоответствующая форма целевого массива (1, 1) при передаче для вывода (None, 16, 16, 1) - PullRequest
0 голосов
/ 16 июня 2020

Я работаю над примером SRGAN, используя Keras. Кажется, я столкнулся с ошибкой во время функции train_on_batch, касающейся потерь, ожидающих, что цели будут иметь ту же форму, что и выходные данные.

ValueError: Целевой массив с формой (1, 1) был передан для вывода формы (Нет, 16, 16, 1) при использовании в качестве потерь mean_squared_error. Эта потеря предполагает, что цели будут иметь ту же форму, что и результат.

  • Ubuntu 18.04
  • Python 3.7.6 (Conda)
  • Tensorflow 2.1
  • Keras 2.2.4

Я все еще учусь и немного не понимаю, как возникла проблема. Любая помощь будет принята с благодарностью.

import glob
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import scipy

from tensorflow.keras import Input
from tensorflow.keras.applications import VGG19
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense
from tensorflow.keras.layers import Conv2D, UpSampling2D

from imageio import imread
from skimage.transform import resize

data_dir = "/mnt/vanguard/datasets/ffhq-dataset/resized256/*.*"
PATH = "/mnt/vanguard/lab/srgan/"
OUTDIR = PATH + "scripts/"

batch_size = 1

hires_shape=(256, 256, 3)
lowres_shape=(64, 64, 3)

# Common optimizer for all networks
common_optimizer = Adam(0.0002, 0.5)

# Load and augment dataset
def load_images(data_dir, batch_size, hires_shape, lowres_shape):
    # Make a list of all images in side the data directory.
    all_images = glob.glob(data_dir)

    # Choose a random batch of images
    images_batch = np.random.choice(all_images, size=batch_size)

    hires_images = []
    lowres_images = []

    for image in images_batch:
        # Get an ndarray of the current image
        img = imread(image)
        img = img.astype(np.float32)

        # Resize images
        img_hires = resize(img, hires_shape)
        img_lowres = resize(img, lowres_shape)

        # Do a random flip
        if np.random.random() < 0.5:
            img_hires = np.fliplr(img_hires)
            img_lowres = np.fliplr(img_lowres)

        hires_images.append(img_hires)
        lowres_images.append(img_lowres)

    return np.array(hires_images), np.array(lowres_images)

# Create residual block
def residual_block(x):
    filters = [64, 64]
    kernel_size = 3
    strides = 1
    padding = "same"
    momentum = 0.8
    activation = "relu"

    res = Conv2D(filters=filters[0], kernel_size=kernel_size, strides=strides, padding=padding)(x)
    res = Activation(activation=activation)(res)
    res = BatchNormalization(momentum=momentum)(res)

    res = Conv2D(filters=filters[1], kernel_size=kernel_size, strides=strides, padding=padding)(res)
    res = BatchNormalization(momentum=momentum)(res)

    # Add res and x
    res = Add()([res, x])
    return res

# Create generator model
def build_generator():
    residual_blocks = 16
    momentum = 0.8
    input_shape = (64, 64, 3)

    # Input Layer of the generator network
    input_layer = Input(shape=input_shape)

    # Add the pre-residual block
    gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(input_layer)

    # Add 16 residual blocks
    res = residual_block(gen1)
    for i in range(residual_blocks - 1):
        res = residual_block(res)

    # Add the post-residual block
    gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
    gen2 = BatchNormalization(momentum=momentum)(gen2)

    # Take the sum of the output from the pre-residual block(gen1) and the post-residual block(gen2)
    gen3 = Add()([gen2, gen1])

    # Add an upsampling block
    gen4 = UpSampling2D(size=2)(gen3)
    gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
    gen4 = Activation('relu')(gen4)

    # Add another upsampling block
    gen5 = UpSampling2D(size=2)(gen4)
    gen5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen5)
    gen5 = Activation('relu')(gen5)

    # Output convolution layer
    gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
    output = Activation('tanh')(gen6)

    # Keras model
    model = Model(inputs=[input_layer], outputs=[output], name='generator')
    return model

# Create a discriminator
def build_discriminator():
    leakyrelu_alpha = 0.2
    momentum = 0.8
    input_shape = (256, 256, 3)

    input_layer = Input(shape=input_shape)

    # Add the first convolution block
    dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
    dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)

    # Add the 2nd convolution block
    dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
    dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
    dis2 = BatchNormalization(momentum=momentum)(dis2)

    # Add the third convolution block
    dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
    dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
    dis3 = BatchNormalization(momentum=momentum)(dis3)

    # Add the fourth convolution block
    dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
    dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
    dis4 = BatchNormalization(momentum=0.8)(dis4)

    # Add the fifth convolution block
    dis5 = Conv2D(256, kernel_size=3, strides=1, padding='same')(dis4)
    dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
    dis5 = BatchNormalization(momentum=momentum)(dis5)

    # Add the sixth convolution block
    dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
    dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
    dis6 = BatchNormalization(momentum=momentum)(dis6)

    # Add the seventh convolution block
    dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
    dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
    dis7 = BatchNormalization(momentum=momentum)(dis7)

    # Add the eight convolution block
    dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
    dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
    dis8 = BatchNormalization(momentum=momentum)(dis8)

    # Add a dense layer
    dis9 = Dense(units=1024)(dis8)
    dis9 = LeakyReLU(alpha=0.2)(dis9)

    # Last dense layer - for classification
    output = Dense(units=1, activation='sigmoid')(dis9)

    model = Model(inputs=[input_layer], outputs=[output], name='discriminator')
    return model

def build_vgg():
    input_shape = (256, 256, 3)

    # Load pre-trained VGG19 model trained on 'Imagenet' dataset
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)
    vgg.outputs=[vgg.layers[9].output]

    input_layer=Input(shape=input_shape)

    # Extract features
    features=vgg(input_layer)

    # Create keras model
    model = Model(inputs=[input_layer], outputs=[features])
    return model

# Build and compile VGG19 network to extract features
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

# Build the generator network
generator = build_generator()

# High res. and low res. images
input_hires = Input(shape=hires_shape)
input_lowres = Input(shape=lowres_shape)

# Generate high-resolution images from low-resolution images
gen_hires_images = generator(input_lowres)

# Extract feature maps of the generated images
features = vgg(gen_hires_images)

# Make the discriminator network as trainable false
discriminator.trainable = False

# Get the probability of generated high-resolution images
probs = discriminator(gen_hires_images)

# Create and compile an adversarial model combining
adversarial_model = Model([input_lowres, input_hires], [probs, features])
adversarial_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=common_optimizer)

logdir = OUTDIR + "logs/"

epochs = 1

for epoch in range(epochs):
    print("Epoch:{}".format(epoch))

# Sample a batch of images
hires_images, lowres_images = load_images(data_dir=data_dir, batch_size=batch_size, lowres_shape=lowres_shape, hires_shape=hires_shape)

# Normalize images
hires_images = hires_images / 127.5 - 1.
lowres_images = lowres_images / 127.5 - 1.

# Generate high resolution images
gen_hires_images = generator.predict(lowres_images)

real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

# Train the discriminator network on real and fake images
d_loss_real = discriminator.train_on_batch(hires_images, real_labels)
d_loss_fake = discriminator.train_on_batch(gen_hires_images, fake_labels)
...