Так определяется tf.GradientTape().gradient()
. Он имеет ту же функциональность, что и tf.gradients()
, за исключением того, что последний нельзя использовать в активном режиме. Из документов из tf.gradients()
:
Возвращается список тензоров длины len(xs)
, где каждый тензор является sum(dy/dx) for y in ys
где xs
равны sources
и ys
равны target
.
Пример 1 :
Итак, скажем target = [y1, y2]
и sources = [x1, x2]
. Результат будет:
[dy1/dx1 + dy2/dx1, dy1/dx2 + dy2/dx2]
Пример 2 :
Вычислить градиенты для потерь на выборку (тензор) против уменьшенных потерь (скаляр)
Let w, b be two variables.
xentropy = [y1, y2] # tensor
reduced_xentropy = 0.5 * (y1 + y2) # scalar
grads = [dy1/dw + dy2/dw, dy1/db + dy2/db]
reduced_grads = [d(reduced_xentropy)/dw, d(reduced_xentropy)/db]
= [d(0.5 * (y1 + y2))/dw, d(0.5 * (y1 + y2))/db]
== 0.5 * grads
Пример Tensorflow приведенного выше фрагмента:
import tensorflow as tf
print(tf.__version__) # 2.1.0
inputs = tf.convert_to_tensor([[0.1, 0], [0.5, 0.51]]) # two two-dimensional samples
w = tf.Variable(initial_value=inputs)
b = tf.Variable(tf.zeros((2,)))
labels = tf.convert_to_tensor([0, 1])
def forward(inputs, labels, var_list):
w, b = var_list
logits = tf.matmul(inputs, w) + b
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
return xentropy
# `xentropy` has two elements (gradients of tensor - gradient
# of each sample in a batch)
with tf.GradientTape() as g:
xentropy = forward(inputs, labels, [w, b])
reduced_xentropy = tf.reduce_mean(xentropy)
grads = g.gradient(xentropy, [w, b])
print(xentropy.numpy()) # [0.6881597 0.71584916]
print(grads[0].numpy()) # [[ 0.20586157 -0.20586154]
# [ 0.2607238 -0.26072377]]
# `reduced_xentropy` is scalar (gradients of scalar)
with tf.GradientTape() as g:
xentropy = forward(inputs, labels, [w, b])
reduced_xentropy = tf.reduce_mean(xentropy)
grads_reduced = g.gradient(reduced_xentropy, [w, b])
print(reduced_xentropy.numpy()) # 0.70200443 <-- scalar
print(grads_reduced[0].numpy()) # [[ 0.10293078 -0.10293077]
# [ 0.1303619 -0.13036188]]
Если вы вычисляете потери (xentropy
) для каждого элемента в пакете, окончательные градиенты каждой переменной будут суммой всех градиентов для каждый образец в партии (что имеет смысл).