Я работаю над примером SRGAN на Keras. При вызове train_on_batch возникает ошибка: функция потерь ожидает, что целевые значения (targets) будут иметь ту же форму, что и выход модели.
ValueError: Целевой массив формы (1, 1) был передан для выхода формы (None, 16, 16, 1) при использовании функции потерь `mean_squared_error`. Эта функция потерь ожидает, что целевые значения будут иметь ту же форму, что и выход.
- Ubuntu 18.04
- Python 3.7.6 (Conda)
- Tensorflow 2.1
- Keras 2.2.4
Я все еще учусь и немного не понимаю, как возникла проблема. Любая помощь будет принята с благодарностью.
import glob
import time
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import scipy
from tensorflow.keras import Input
from tensorflow.keras.applications import VGG19
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense
from tensorflow.keras.layers import Conv2D, UpSampling2D
from imageio import imread
from skimage.transform import resize
# Glob pattern selecting the training images, plus the project directories.
data_dir = "/mnt/vanguard/datasets/ffhq-dataset/resized256/*.*"
PATH = "/mnt/vanguard/lab/srgan/"
OUTDIR = f"{PATH}scripts/"

# Number of images per training step.
batch_size = 1

# Shape tuples handed to skimage's resize() and to the Keras Input layers.
hires_shape = (256, 256, 3)
lowres_shape = (64, 64, 3)

# One Adam instance shared by every network compiled in this script.
# The two positional arguments are the first two parameters of Adam
# (learning rate and beta_1 in tf.keras).
common_optimizer = Adam(0.0002, 0.5)
# Load and augment dataset
def load_images(data_dir, batch_size, hires_shape, lowres_shape):
    """Load a random batch of images at two resolutions.

    Args:
        data_dir: Glob pattern matching the image files to sample from.
        batch_size: Number of files to draw. Sampling goes through
            ``np.random.choice`` with its default settings, so the same
            file can be drawn more than once within a batch.
        hires_shape: Target shape passed to ``resize`` for the
            high-resolution copy of each image.
        lowres_shape: Target shape passed to ``resize`` for the
            low-resolution copy of each image.

    Returns:
        A ``(hires_images, lowres_images)`` tuple of NumPy arrays, each
        stacking one resized image per sampled file, in the same order.

    Raises:
        ValueError: If ``data_dir`` matches no files.
    """
    # Make a list of all images inside the data directory.
    all_images = glob.glob(data_dir)
    if not all_images:
        # Fail with the offending pattern instead of the generic error
        # np.random.choice raises for an empty population.
        raise ValueError(
            "No images found matching pattern: {}".format(data_dir)
        )

    # Choose a random batch of image paths.
    images_batch = np.random.choice(all_images, size=batch_size)

    hires_images = []
    lowres_images = []
    for image_path in images_batch:
        # Read the current image as a float32 ndarray.
        img = imread(image_path).astype(np.float32)

        # Both resolutions are produced from the same source image.
        img_hires = resize(img, hires_shape)
        img_lowres = resize(img, lowres_shape)

        # Random horizontal flip; one draw decides for both copies so the
        # low-res/high-res pair stays aligned.
        if np.random.random() < 0.5:
            img_hires = np.fliplr(img_hires)
            img_lowres = np.fliplr(img_lowres)

        hires_images.append(img_hires)
        lowres_images.append(img_lowres)

    return np.array(hires_images), np.array(lowres_images)
# Create residual block
def residual_block(x):
    """Apply one residual block to the tensor ``x``.

    The residual branch is Conv2D -> ReLU -> BatchNormalization ->
    Conv2D -> BatchNormalization. Both convolutions use 64 filters,
    kernel size 3, stride 1 and "same" padding; both batch-norm layers
    use momentum 0.8. The branch output is added element-wise to ``x``.

    Args:
        x: Input tensor; it is also the skip-connection operand, so it
            must be addable to the 64-channel branch output.

    Returns:
        The tensor produced by ``Add()([branch, x])``.
    """
    conv_kwargs = dict(kernel_size=3, strides=1, padding="same")
    bn_momentum = 0.8

    branch = Conv2D(filters=64, **conv_kwargs)(x)
    branch = Activation(activation="relu")(branch)
    branch = BatchNormalization(momentum=bn_momentum)(branch)

    branch = Conv2D(filters=64, **conv_kwargs)(branch)
    branch = BatchNormalization(momentum=bn_momentum)(branch)

    # Skip connection: add the block input back onto the branch output.
    return Add()([branch, x])
# Create generator model
def build_generator():
    """Build the generator network.

    Layout: a 9x9 convolution with ReLU, 16 residual blocks, a 3x3
    convolution with batch normalization whose output is summed with the
    first convolution's output, two (UpSampling2D x2 -> Conv2D -> ReLU)
    stages, and a final 9x9 convolution to 3 channels followed by tanh.

    Returns:
        A Keras ``Model`` named ``'generator'`` taking a (64, 64, 3)
        input.
    """
    num_residual_blocks = 16
    bn_momentum = 0.8

    # Input layer of the generator network.
    lowres_input = Input(shape=(64, 64, 3))

    # Pre-residual block.
    pre_residual = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(lowres_input)

    # Chain of residual blocks, each fed by the previous one.
    trunk = pre_residual
    for _ in range(num_residual_blocks):
        trunk = residual_block(trunk)

    # Post-residual block.
    post_residual = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(trunk)
    post_residual = BatchNormalization(momentum=bn_momentum)(post_residual)

    # Long skip connection around the whole residual trunk.
    merged = Add()([post_residual, pre_residual])

    # Two identical upsampling stages, each doubling the spatial size.
    upsampled = merged
    for _ in range(2):
        upsampled = UpSampling2D(size=2)(upsampled)
        upsampled = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(upsampled)
        upsampled = Activation('relu')(upsampled)

    # Output convolution back to 3 channels, squashed by tanh.
    rgb = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(upsampled)
    generated = Activation('tanh')(rgb)

    return Model(inputs=[lowres_input], outputs=[generated], name='generator')
# Create a discriminator
def build_discriminator():
    """Build the discriminator network.

    Eight convolution blocks (3x3 kernels, "same" padding, LeakyReLU with
    alpha 0.2) are followed by Dense(1024) + LeakyReLU and a final
    Dense(1) with sigmoid. Every block except the first also applies
    batch normalization with momentum 0.8.

    There is no flatten or pooling layer between the last convolution
    block and the Dense layers, so the Dense layers act on the last axis
    only and the output keeps its spatial dimensions. With the
    (256, 256, 3) input and four stride-2 convolutions, Keras reports the
    output shape as (None, 16, 16, 1); training targets must match it.

    Returns:
        A Keras ``Model`` named ``'discriminator'``.
    """
    leaky_alpha = 0.2
    bn_momentum = 0.8

    image_input = Input(shape=(256, 256, 3))

    # First convolution block: no batch normalization.
    x = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(image_input)
    x = LeakyReLU(alpha=leaky_alpha)(x)

    # Blocks two to eight as (filters, strides); stride 2 halves the
    # spatial resolution.
    block_specs = [
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    ]
    for n_filters, stride in block_specs:
        x = Conv2D(filters=n_filters, kernel_size=3, strides=stride, padding='same')(x)
        x = LeakyReLU(alpha=leaky_alpha)(x)
        x = BatchNormalization(momentum=bn_momentum)(x)

    # Dense head, applied along the channel axis of the feature map.
    x = Dense(units=1024)(x)
    x = LeakyReLU(alpha=leaky_alpha)(x)

    # Last dense layer - sigmoid score.
    score = Dense(units=1, activation='sigmoid')(x)

    return Model(inputs=[image_input], outputs=[score], name='discriminator')
def build_vgg():
    """Build a VGG19-based feature extractor.

    Loads VGG19 with ImageNet weights and without its classification
    top, then exposes the output of the layer at index 9 of
    ``vgg.layers`` as the extracted features.

    Returns:
        A Keras ``Model`` taking a (256, 256, 3) input and returning the
        activations of VGG19 layer index 9.
    """
    input_shape = (256, 256, 3)

    # Load pre-trained VGG19 model trained on the 'Imagenet' dataset.
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)

    # Build an explicit sub-model that ends at layer 9. Assigning to
    # `vgg.outputs` on an already-built model relies on Keras internals
    # and is not guaranteed to change what calling the model returns;
    # constructing a Model between two tensors is the supported way to
    # tap an intermediate layer.
    feature_extractor = Model(inputs=vgg.input, outputs=vgg.layers[9].output)

    # Wrap the extractor behind a fresh input layer.
    input_layer = Input(shape=input_shape)
    features = feature_extractor(input_layer)

    return Model(inputs=[input_layer], outputs=[features])
# Build the VGG19 feature extractor, mark it non-trainable and compile it.
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
# Build and compile the discriminator as a standalone model.
# NOTE(review): this compile runs before `discriminator.trainable = False`
# further down; presumably the order is deliberate so the standalone
# discriminator keeps training - confirm against Keras trainable/compile
# semantics before reordering.
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])
# Build the generator network (it is not compiled on its own here).
generator = build_generator()
# Symbolic inputs for high-resolution and low-resolution images.
input_hires = Input(shape=hires_shape)
input_lowres = Input(shape=lowres_shape)
# Generate high-resolution images from the low-resolution input.
gen_hires_images = generator(input_lowres)
# Extract VGG feature maps of the generated images.
features = vgg(gen_hires_images)
# Mark the discriminator non-trainable before wiring it into the combined model.
discriminator.trainable = False
# Discriminator score for the generated high-resolution images.
probs = discriminator(gen_hires_images)
# Create and compile the adversarial model combining generator, discriminator
# and VGG. It has two outputs: `probs` (binary cross-entropy, weight 1e-3) and
# `features` (MSE, weight 1).
# `input_hires` is listed as a model input, but neither output is computed from it.
adversarial_model = Model([input_lowres, input_hires], [probs, features])
adversarial_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=common_optimizer)
# Log directory path under OUTDIR.
logdir = OUTDIR + "logs/"
# Number of iterations of the training loop.
epochs = 1
# Training loop: each iteration samples one batch and updates the
# discriminator on real and on generated images.
for epoch in range(epochs):
    print("Epoch:{}".format(epoch))

    # Sample a batch of images
    hires_images, lowres_images = load_images(data_dir=data_dir, batch_size=batch_size, lowres_shape=lowres_shape, hires_shape=hires_shape)

    # Normalize images (assumes pixel values in [0, 255], mapping them
    # to [-1, 1] - TODO confirm the range produced by load_images)
    hires_images = hires_images / 127.5 - 1.
    lowres_images = lowres_images / 127.5 - 1.

    # Generate high resolution images
    gen_hires_images = generator.predict(lowres_images)

    # The loss requires targets shaped exactly like the discriminator's
    # output. That output is not (batch, 1): the discriminator keeps its
    # spatial dimensions, so a (batch_size, 1) label array is rejected by
    # train_on_batch. Derive the label shape from the model itself.
    label_shape = (batch_size,) + tuple(discriminator.output_shape[1:])
    real_labels = np.ones(label_shape)
    fake_labels = np.zeros(label_shape)

    # Train the discriminator network on real and fake images
    d_loss_real = discriminator.train_on_batch(hires_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(gen_hires_images, fake_labels)