Почему тензор потока map_fn медленнее, чем python для циклов - PullRequest
2 голосов
/ 08 января 2020

Я хочу применить независимое вращение к каждой партии точек в тензоре с формой [batch, n_pts, 3]. Я реализовал это двумя отдельными способами. Первым было преобразовать тензор в массив numpy и основу c python для циклов. Вторая использует тензорные потоки tf.map_fn() для устранения циклов for. Однако, когда я запускаю этот процесс, тензор потока map_fn() медленнее на ~ 100 раз.

У меня вопрос, неправильно ли я использую здесь функцию tf.map_fn(). Когда вы ожидаете получить прирост производительности при использовании tf.map_fn() по сравнению со стандартным numpy / python?

Если я использую его правильно, то я хотел бы знать, почему тензор потока tf.map_fn() так много медленнее.

Мой код для воспроизведения эксперимента:

import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


def rotate_tf(pc):

    theta = tf.random.uniform((1,)) * 2.0 * np.pi
    cosval = tf.math.cos(theta)[0]
    sinval = tf.math.sin(theta)[0]

    R = tf.Variable([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    def dot(p):
        return K.dot(R, tf.expand_dims(p, axis=-1))

    return tf.squeeze(tf.map_fn(dot, pc))


def rotate_np(pc):

    theta = np.random.uniform() * 2.0 * np.pi
    cosval = np.cos(theta)
    sinval = np.sin(theta)

    R = np.array([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    for idx, p in enumerate(pc):
        pc[idx] = np.dot(R, p)

    return pc


pts = tf.random.uniform((8, 100, 3))
n = 10

# Start tensorflow map_fn() method
start = time.time()

for i in range(n):
    pts = tf.map_fn(rotate_tf, pts)

print('processed tf in: {:.4f} s'.format(time.time()-start))

# Start numpy method
start = time.time()

for i in range(n):

    pts = pts.numpy()
    for i, p in enumerate(pts):
        pts[i] = rotate_np(p)
    pts = tf.Variable(pts)

print('processed np in: {:.4f} s'.format(time.time()-start))

Вывод для этого:

processed tf in: 3.8427 s
processed np in: 0.0314 s
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...