Использование GradientTape для вычисления градиентов предсказаний относительно некоторых тензоров - PullRequest
0 голосов
/ 19 марта 2020

Я пытаюсь реализовать WGAN с GP в TensorFlow 2.0. Чтобы рассчитать штраф за градиент, необходимо рассчитать градиенты прогнозов относительно входных изображений.

Теперь, чтобы сделать его немного более удобным, вместо того, чтобы вычислять градиенты предсказаний относительно всех входных изображений, он вычисляет интерполированные точки данных вдоль линий исходных и поддельных точек данных и использует их как входы.

Чтобы реализовать это, я сначала разрабатываю функцию compute_gradients, которая будет принимать некоторые прогнозы и возвращать их градиенты относительно некоторых входных изображений. Сначала я подумал сделать это с tf.keras.backend.gradients, но он не будет работать в активном режиме. Итак, я пытаюсь сделать это сейчас, используя GradientTape.

Вот код, который я использую для проверки:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
import numpy as np

# Comes from Generative Deep Learning by David Foster
class RandomWeightedAverage(tf.keras.layers.Layer):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    """Provides a (random) weighted average between real and generated image samples"""
    def call(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

# Dummy critic
def make_critic():
    critic = Sequential()
    inputShape = (28, 28, 1)

    critic.add(Conv2D(32, (5, 5), padding="same", strides=(2, 2),
        input_shape=inputShape))
    critic.add(LeakyReLU(alpha=0.2))

    critic.add(Conv2D(64, (5, 5), padding="same", strides=(2, 2)))
    critic.add(LeakyReLU(alpha=0.2))

    critic.add(Flatten())
    critic.add(Dense(512))
    critic.add(LeakyReLU(alpha=0.2))
    critic.add(Dropout(0.3))
    critic.add(Dense(1))

    return critic

# Gather dataset
((X_train, _), (X_test, _)) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# Note that I am using test images as fake images for testing purposes
interpolated_img = RandomWeightedAverage(32)([X_train[0:32].astype("float"), X_test[32:64].astype("float")])

# Compute gradients of the predictions with respect to the interpolated images
critic = make_critic()
with tf.GradientTape() as tape:
    y_pred = critic(interpolated_img)
gradients = tape.gradient(y_pred, interpolated_img)

Градиенты становятся None. Я что-то здесь упускаю?

1 Ответ

1 голос
/ 20 марта 2020

Градиенты предсказаний относительно некоторых тензоров ... Я что-то здесь упускаю?

Да. Вам необходимо tape.watch(interpolated_img):

with tf.GradientTape() as tape:
    tape.watch(interpolated_img)
    y_pred = critic(interpolated_img)

GradientTape, чтобы сохранить промежуточные значения прямого прохода для вычисления градиентов. Обычно вам нужны градиенты переменных WRT. Таким образом, он не сохраняет след вычислений, начиная с тензоров, возможно, для экономии памяти.

Если вы хотите, чтобы градиент WRT был тензором, вам нужно явно указать tape.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...