Я пытаюсь построить модель передачи нейронного стиля в реальном времени, следуя архитектуре, описанной в этой статье , где обучаемая модель обрабатывает входное изображение, а предварительно обученная модель (vgg19)используется для расчета как стиля, так и потери контента. Цель состоит в том, чтобы обучить модель, чтобы выходные данные могли иметь содержимое целевого изображения со стилем изображения целевого стиля.
Проблема возникает при использовании tf.GradientTape () для вычисления градиентов:
with tf.GradientTape() as tape:
output_image = model(input_image, training=True)
losses = compute_loss(input_image,...)
grads = tape.gradient(losses, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
Внутри 'compute_loss ()' используется модель vgg19, а также другие функции для расчета потери содержимого и стиля. «Модель» - это тот, кого нужно обучать. На каждой итерации я распечатываю потери, и они не прогрессируют, а полученные градиенты (грады) - все нули.
Используемая «модель» показана ниже, которая в основном заимствована из Этот GitHubрепо.
class MyInstanceNorm(keras.layers.Layer):
def build(self, batch_input_shape):
self.scale = self.add_weight(name='scale', shape=[batch_input_shape[-1]],
initializer='ones',dtype=tf.float32)
self.shift = self.add_weight(name='shift', shape=[batch_input_shape[-1]],
initializer='zeros', dtype=tf.float32)
super().build(batch_input_shape)
def call(self, X, training=True):
if training:
mean, variance = tf.nn.moments(X, axes=[1,2], keepdims=True)
std = tf.sqrt(variance)
epsilon = 1e-3
X_ = (X - mean) / (std + epsilon)
return self.scale * X + self.shift
else:
return X
def conv_layer(net, filters, kernel_size, strides, padding='SAME', relu=True, transpose=False, input_shape=None):
if not transpose:
if input_shape:
x = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
strides=strides, padding=padding,
input_shape=input_shape,
kernel_initializer=keras.initializers.TruncatedNormal(0, 1, 1))(net)
else:
x = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
strides=strides, padding=padding,
kernel_initializer=keras.initializers.TruncatedNormal(0, 1, 1))(net)
else:
x = keras.layers.Conv2DTranspose(filters=filters, kernel_size=kernel_size,
strides=strides, padding=padding,
kernel_initializer=keras.initializers.TruncatedNormal(0, 1, 1))(net)
x = MyInstanceNorm()(x)
if relu:
x = keras.activations.relu(x)
return x
def residual_block(net):
tmp1 = conv_layer(net, 128, 3, 1)
return net + conv_layer(tmp1, 128, 3, 1, relu=False)
def NST_model(init_image):
conv1 = conv_layer(init_image, 32, 9, 1)
conv2 = conv_layer(conv1, 64, 3, 2)
conv3 = conv_layer(conv2, 128, 3, 2)
resid1 = residual_block(conv3)
resid2 = residual_block(resid1)
resid3 = residual_block(resid2)
resid4 = residual_block(resid3)
resid5 = residual_block(resid4)
conv_t1 = conv_layer(resid5, 64, 3, 2, transpose=True)
conv_t2 = conv_layer(conv_t1, 32, 3, 2, transpose=True)
conv_t3 = conv_layer(conv_t2, 3, 9, 1, relu=False, transpose=True)
out = keras.activations.tanh(conv_t3) * 150 + 255./2
return out
# Building the actual model
input_image = keras.layers.Input(shape=[288, 288, 3])
output_image = NST_model(input_image)
ITN_model = keras.Model(input_image, output_image)
Когда дело доходит до 'compute_loss', оно показывается здесь (с Здесь ):
def compute_loss(loss_model, init_image, loss_weights, gram_style_features, content_features):
style_weight, content_weight = loss_weights
model_outputs = loss_model(init_image)
style_output_features = model_outputs[:num_style_layers]
content_output_features = model_outputs[num_style_layers:]
style_loss = 0
content_loss = 0
weight_per_style_layer = 1.0 / float(num_style_layers)
weight_per_content_layer = 1.0 / float(num_content_layers)
for target_style, image_style in zip(gram_style_features, style_output_features):
style_loss += weight_per_style_layer * get_style_loss(target_style, image_style[0])
for target_content, image_content in zip(content_features, content_output_features):
content_loss += weight_per_content_layer * get_content_loss(target_content, image_content[0])
style_loss *= style_weight
content_loss *= content_weight
loss = style_loss + content_loss
return loss, style_loss, content_loss
Есть идеи, что может быть причиной? Заранее спасибо.