Как вычислить градиенты с помощью tf.scatter_sub? - PullRequest
1 голос
/ 16 января 2020

При реализации lambda-opt (алгоритм, опубликованный на KDD'19) в тензорном потоке я столкнулся с проблемой вычисления градиентов с помощью tf.scatter_sub

θ, относящихся к вложению матрица для docid. Формулировка:

θ (t + 1) = θ (t) - α * (град + 2 * λ * θ),

delta = theta_grad_no_reg.values * lr + 2 * lr * cur_scale * cur_theta
next_theta_tensor = tf.scatter_sub(theta,theta_grad_no_reg.indices,delta)

, затем я использую θ (t + 1 ) для некоторых вычислений. Наконец, я хочу вычислить градиенты относительно λ, а не θ.

Но градиент равен None.

Я написал демо-версию примерно так:

import tensorflow as tf

w = tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32)
y = tf.constant([5.0], dtype=tf.float32)

# θ
emb_matrix = tf.get_variable("embedding_name", shape=(10, 3),
                    initializer=tf.random_normal_initializer(),dtype=tf.float32)
# get one line emb
cur_emb=tf.nn.embedding_lookup(emb_matrix,[0])
# The λ matrix
doc_lambda = tf.get_variable(name='docid_lambda', shape=(10, 3),
                             initializer=tf.random_normal_initializer(), dtype=tf.float32)
# get one line λ
cur_lambda=tf.nn.embedding_lookup(doc_lambda, [0])

# θ(t+1) Tensor("ScatterSub:0", shape=(10, 3), dtype=float32_ref)
next_emb_matrix=tf.scatter_sub(emb_matrix, [0], (cur_emb *cur_lambda)) 
# do some compute with θ(t+1) Tensor ,not Variable
next_cur_emb=tf.nn.embedding_lookup(next_emb_matrix,[0])

y_ = tf.matmul(next_cur_emb, w)
loss = tf.reduce_mean((y - y_) ** 2)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
grad_var_list=optimizer.compute_gradients(loss)
print(grad_var_list)
# [(None, <tf.Variable 'embedding_name:0' shape=(10, 3) dtype=float32_ref>), (None, <tf.Variable 'docid_lambda:0' shape=(10, 3) dtype=float32_ref>)]

градиента тоже нет. Кажется, что tf.scatter_sub op не обеспечивает градиент?

Спасибо за вашу помощь!

Если у вас есть интерес к этому алгоритму, вы можете найти его , но это не важно в этом вопросе.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...