Снова неагрегированные градиенты в тензорном потоке - PullRequest
0 голосов
/ 20 марта 2019

Мой вопрос касается tf.gradient (ys, xs), который всегда возвращает сумму (dy / dx) для всех y в ys. Суммирование неявно, и, кажется, нет официального способа получить список градиентов для каждого y (например, пакет из 100 примеров).

Я знаю, что здесь есть посты с похожим вопросом, но ни один из них не дает хорошего ответа. Мой вопрос больше связан с решением, данным в этом обсуждении: GitHub Issues . Хотелось бы узнать, пробовал ли кто-нибудь метод, указанный пользователем goodfeli? Вот моя попытка:

with tf.device('/GPU:4'):
    batch_size = 1000
    W = tf.Variable([[1.,2.],[3.,4.],[4.,8.],[3.,6.]])
    x = tf.placeholder(dtype = tf.float32, shape = [None,4])

    y = tf.tensordot(x,W,axes=1)
    loss = tf.reduce_prod(y)
    grad_old = tf.gradients(ys=loss,xs=W)

    ##Non Aggregate

    examples = tf.split(x,num_or_size_splits=batch_size)
    weight_copies = [tf.identity(W) for xs in examples]
    output = tf.stack([tf.tensordot(xs,ws,axes=1) for xs,ws in zip(examples,weight_copies)])
    batch_loss = tf.squeeze(tf.reduce_prod(output,axis=2))
    batch_grad = tf.gradients(ys=batch_loss,xs=weight_copies)
#Run session
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=False))
tf.global_variables_initializer().run(session=sess)

val = np.random.rand(batch_size,4)
fd = {x:val}
st0 = timeit.default_timer()
grad_b = np.array(sess.run(batch_grad,feed_dict=fd))
print("Batch: %f" %(timeit.default_timer()-st0))

st1 = timeit.default_timer()
grad_s = []
for s in val:
    val_x = [s]
    fds = {x:val_x}
    grad_s.append(sess.run(grad_old,feed_dict=fds)[0])
grad_s = np.array(grad_s)
print("Single: %f" %(timeit.default_timer()-st1))

Однако, когда я запускаю это, время, затрачиваемое на пакетный метод, примерно в 10 раз медленнее, чем на цикл по каждому примеру. Я нахожу это довольно странным. Есть ли какая-то причина, по которой я веду себя так, когда поток github, кажется, подтверждает, что пакетный метод работает быстрее. Является ли моя установка слишком простой, что накладные расходы убивают какую-либо эффективность? Любая помощь будет очень признательна.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...