Я пытаюсь реализовать WGAN с GP в TensorFlow 2.0. Чтобы рассчитать штраф за градиент, необходимо рассчитать градиенты прогнозов относительно входных изображений.
Теперь, чтобы сделать его немного более удобным, вместо того, чтобы вычислять градиенты предсказаний относительно всех входных изображений, он вычисляет интерполированные точки данных вдоль линий исходных и поддельных точек данных и использует их как входы.
Чтобы реализовать это, я сначала разрабатываю функцию compute_gradients
, которая будет принимать некоторые прогнозы и возвращать их градиенты относительно некоторых входных изображений. Сначала я подумал сделать это с tf.keras.backend.gradients
, но он не будет работать в активном режиме. Итак, я пытаюсь сделать это сейчас, используя GradientTape
.
Вот код, который я использую для проверки:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
import numpy as np
# Comes from Generative Deep Learning by David Foster
class RandomWeightedAverage(tf.keras.layers.Layer):
def __init__(self, batch_size):
super().__init__()
self.batch_size = batch_size
"""Provides a (random) weighted average between real and generated image samples"""
def call(self, inputs):
alpha = K.random_uniform((self.batch_size, 1, 1, 1))
return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
# Dummy critic
def make_critic():
critic = Sequential()
inputShape = (28, 28, 1)
critic.add(Conv2D(32, (5, 5), padding="same", strides=(2, 2),
input_shape=inputShape))
critic.add(LeakyReLU(alpha=0.2))
critic.add(Conv2D(64, (5, 5), padding="same", strides=(2, 2)))
critic.add(LeakyReLU(alpha=0.2))
critic.add(Flatten())
critic.add(Dense(512))
critic.add(LeakyReLU(alpha=0.2))
critic.add(Dropout(0.3))
critic.add(Dense(1))
return critic
# Gather dataset
((X_train, _), (X_test, _)) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
# Note that I am using test images as fake images for testing purposes
interpolated_img = RandomWeightedAverage(32)([X_train[0:32].astype("float"), X_test[32:64].astype("float")])
# Compute gradients of the predictions with respect to the interpolated images
critic = make_critic()
with tf.GradientTape() as tape:
y_pred = critic(interpolated_img)
gradients = tape.gradient(y_pred, interpolated_img)
Градиенты становятся None
. Я что-то здесь упускаю?