После долгих усилий мне удалось построить tensorflow 2
реализацию существующего проекта pytorch
передачи стилей. Затем я хотел получить все приятные дополнительные функции, которые доступны через стандартное обучение Keras, например model.fit()
.
Но та же модель терпит неудачу при обучении через model.fit()
. Модель, кажется, изучает функции контента, но не может изучить функции стиля. Это диаграмма модели в вопросе:
def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
""" creates a VGG model that returns output values for the given layers
see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19
Returns:
function(x, preprocess=True):
Args:
x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
Returns:
a tuple of lists, ([content_features], [style_features])
usage:
(content_features, style_features) = vgg_layers16(content_layers, style_layers)(x_train)
"""
preprocessingFn = tf.keras.applications.vgg19.preprocess_input
base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
base_model.trainable = False
content_features = [base_model.get_layer(name).output for name in content_layers]
style_features = [base_model.get_layer(name).output for name in style_layers]
output_features = content_features + style_features
model = Model( inputs=base_model.input, outputs=output_features, name="vgg_layers")
model.trainable = False
def _get_features(x, preprocess=True):
"""
Args:
x: expecting tensor, domain=255. hwcRGB
"""
if preprocess and callable(preprocessingFn):
x = preprocessingFn(x)
output = model(x) # call as tf.keras.Layer()
return ( output[:len(content_layers)], output[len(content_layers):] )
return _get_features
class VGG_Features():
""" get content and style features from VGG model """
def __init__(self, loss_model, style_image=None, target_style_gram=None):
self.loss_model = loss_model
if style_image is not None:
assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
self.style_image = style_image
self.target_style_gram = VGG_Features.get_style_gram(self.loss_model, self.style_image)
if target_style_gram is not None:
self.target_style_gram = target_style_gram
@staticmethod
def get_style_gram(vgg_features_model, style_image):
style_batch = tf.repeat( style_image[tf.newaxis,...], repeats=_batch_size, axis=0)
# show([style_image], w=128, domain=(0.,255.) )
# B, H, W, C = style_batch.shape
(_, style_features) = vgg_features_model( style_batch , preprocess=True ) # hwcRGB
target_style_gram = [ fnstf_utils.gram(value) for value in style_features ] # list
return target_style_gram
def __call__(self, input_batch):
content_features, style_features = self.loss_model( input_batch, preprocess=True )
style_gram = tuple(fnstf_utils.gram(value) for value in style_features) # tuple(<generator>)
return (content_features[0],) + style_gram # tuple = tuple + tuple
class TransformerNetwork_VGG(tf.keras.Model):
def __init__(self, transformer=transformer, vgg_features=vgg_features):
super(TransformerNetwork_VGG, self).__init__()
self.transformer = transformer
# type: tf.keras.models.Model
# input_shapes: (None, 256,256,3)
# output_shapes: (None, 256,256,3)
style_model = {
'content_layers':['block5_conv2'],
'style_layers': ['block1_conv1',
'block2_conv1',
'block3_conv1',
'block4_conv1',
'block5_conv1']
}
vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )
self.vgg_features = VGG_Features(vgg_model, style_image=style_image, batch_size=batch_size)
# input_shapes: (None, 256,256,3)
# output_shapes: [(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
# [ content_loss, style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]
def call(self, inputs):
x = inputs # shape=(None, 256,256,3)
# shape=(None, 256,256,3)
generated_image = self.transformer(x)
# shape=[(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
vgg_feature_losses = self.vgg(generated_image)
return vgg_feature_losses # tuple(content1, style1, style2, style3, style4, style5)
Стиль изображения
FEATURE_WEIGHTS = [1.0, 1,0, 1,0, 1,0, 1,0, 1,0]
GradientTape
обучение
С tf.GradientTape()
l oop я вручную обрабатываю несколько выходов, например, набор из 6 тензоров от TransformerNetwork_VGG(x_train)
. Этот метод обучается правильно.
@tf.function()
def train_step(x_train, y_true, loss_weights=None, log_freq=10):
with tf.GradientTape() as tape:
y_pred = TransformerNetwork_VGG(x_train)
generated_content_features = y_pred[:1]
generated_style_gram = y_pred[1:]
y_true = TransformerNetwork_VGG.vgg(x_train)
target_content_features = y_true[:1]
target_style_gram = TransformerNetwork_VGG.vgg.target_style_gram
content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features, weights)
style_loss = tuple(get_MEAN_mse_loss(x,y)*w for x,y,w in zip(target_style_gram, generated_style_gram, weights))
total_loss = content_loss + = tf.reduce_sum(style_loss)
TransformerNetwork = TransformerNetwork_VGG.transformer
grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))
# GradientTape epoch=5:
# losses: [ 6078.71 70.23 4495.13 13817.65 88217.99 48.36]
model.fit()
learning
С tf.keras.models.Model.fit()
, несколько выходов, например, набор из 6 тензоров , подаются на функцию потерь индивидуально как loss(y_pred, y_true)
, а затем умножаются на правильный вес на reduction
. Этот метод учится приближаться к content_image, но не учится , чтобы минимизировать потери стиля! Я не могу понять, почему.
history = TransformerNetwork_VGG.fit(
x=train_dataset.repeat(NUM_EPOCHS),
epochs=NUM_EPOCHS,
steps_per_epoch=NUM_BATCHES,
callbacks=callbacks,
)
# model.fit() epoch=5:
# losses: [ 4661.08 219.95 6959.01 4897.39 209201.16 84.68]]
50 эпох, с увеличенными style_weights, FEATURE_WEIGHTS = [0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79] # потеря стиля буста x100
step = 50, потери = [269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]
проверка mse
потери * веса
Я проверил свою модель с потерями и весами, зафиксированными следующим образом * FEATURE_WEIGHTS = SEQ = [1., 2., 3., 4., 5., 6., ] * MSELoss (y_true, y_pred) == tf.ones () равной формы и подтвердил, что model.fit()
правильно обрабатывает несколько выходных потерь * весов
У меня есть проверил все, что могу придумать, но я не могу понять, как заставить модель учиться правильно с model.fit()
. Чего мне не хватает ??
Полный блокнот доступен здесь: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb