Я дам вам несколько разных методов для реализации этого. Я думаю, что наиболее очевидным решением является использование tf.scan
:
import tensorflow as tf
def apply_momentum_scan(m, p, axis=0):
# Put axis first
axis = tf.convert_to_tensor(axis, dtype=tf.int32)
perm = tf.concat([[axis], tf.range(axis), tf.range(axis + 1, tf.rank(m))], axis=0)
m_t = tf.transpose(m, perm)
# Do computation
res_t = tf.scan(lambda a, x: a * p + x, m_t)
# Undo transpose
perm_t = tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, tf.rank(m))], axis=0)
return tf.transpose(res_t, perm_t)
Однако вы также можете реализовать это как конкретный матричный продукт, если вы строите матрицу экспоненциальных факторов:
import tensorflow as tf
def apply_momentum_matmul(m, p, axis=0):
# Put axis first and reshape
m = tf.convert_to_tensor(m)
p = tf.convert_to_tensor(p)
axis = tf.convert_to_tensor(axis, dtype=tf.int32)
perm = tf.concat([[axis], tf.range(axis), tf.range(axis + 1, tf.rank(m))], axis=0)
m_t = tf.transpose(m, perm)
shape_t = tf.shape(m_t)
m_tr = tf.reshape(m_t, [shape_t[0], -1])
# Build factors matrix
r = tf.range(tf.shape(m_tr)[0])
p_tr = tf.linalg.band_part(p ** tf.dtypes.cast(tf.expand_dims(r, 1) - r, p.dtype), -1, 0)
# Do computation
res_tr = p_tr @ m_tr
# Reshape back and undo transpose
res_t = tf.reshape(res_tr, shape_t)
perm_t = tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, tf.rank(m))], axis=0)
return tf.transpose(res_t, perm_t)
Это также можно переписать, чтобы избежать первого переноса (который в TensorFlow стоит) с tf.tensordot
:
import tensorflow as tf
def apply_momentum_tensordot(m, p, axis=0):
# Put axis first and reshape
m = tf.convert_to_tensor(m)
# Build factors matrix
r = tf.range(tf.shape(m)[axis])
p_mat = tf.linalg.band_part(p ** tf.dtypes.cast(tf.expand_dims(r, 1) - r, p.dtype), -1, 0)
# Do computation
res_t = tf.linalg.tensordot(m, p_mat, axes=[[axis], [1]])
# Transpose
last_dim = tf.rank(res_t) - 1
perm_t = tf.concat([tf.range(axis), [last_dim], tf.range(axis, last_dim)], axis=0)
return tf.transpose(res_t, perm_t)
Три функции будут использовать аналогичным образом:
import tensorflow as tf
p = tf.Variable(0.5, dtype=tf.float32)
m = tf.constant([[0, 1, 2, 3, 4],
[1, 3, 5, 7, 10],
[1, 1, 1, -1, 0]], tf.float32)
# apply_momentum is one of the functions above
print(apply_momentum(m, p, axis=0).numpy())
# [[ 0. 1. 2. 3. 4. ]
# [ 1. 3.5 6. 8.5 12. ]
# [ 1.5 2.75 4. 3.25 6. ]]
print(apply_momentum(m, p, axis=1).numpy())
# [[ 0. 1. 2.5 4.25 6.125 ]
# [ 1. 3.5 6.75 10.375 15.1875]
# [ 1. 1.5 1.75 -0.125 -0.0625]]
Использование матричного произведения более асимптотически сложно, но может быть быстрее сканирования. Вот небольшой тест:
import tensorflow as tf
import numpy as np
# Make test data
tf.random.set_seed(0)
p = tf.constant(0.5, dtype=tf.float32)
m = tf.random.uniform([100, 30, 50], dtype=tf.float32)
# Axis 0
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_matmul(m, p, 0).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_tensordot(m, p, 0).numpy()))
# True
%timeit apply_momentum_scan(m, p, 0)
# 11.5 ms ± 610 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 0)
# 1.36 ms ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 0)
# 1.62 ms ± 7.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Axis 1
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_matmul(m, p, 1).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_tensordot(m, p, 1).numpy()))
# True
%timeit apply_momentum_scan(m, p, 1)
# 4.27 ms ± 60.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 1)
# 1.27 ms ± 36.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 1)
# 1.2 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# Axis 2
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_matmul(m, p, 2).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_tensordot(m, p, 2).numpy()))
# True
%timeit apply_momentum_scan(m, p, 2)
# 6.29 ms ± 64.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit apply_momentum_matmul(m, p, 2)
# 1.41 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit apply_momentum_tensordot(m, p, 2)
# 1.05 ms ± 26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Итак, матричный продукт, похоже, выиграл. Давайте посмотрим, если это масштабируется:
import tensorflow as tf
import numpy as np
# Make test data
tf.random.set_seed(0)
p = tf.constant(0.5, dtype=tf.float32)
m = tf.random.uniform([1000, 300, 500], dtype=tf.float32)
# Axis 0
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_matmul(m, p, 0).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 0).numpy(), apply_momentum_tensordot(m, p, 0).numpy()))
# True
%timeit apply_momentum_scan(m, p, 0)
# 784 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 0)
# 1.13 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 0)
# 1.3 s ± 27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Axis 1
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_matmul(m, p, 1).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 1).numpy(), apply_momentum_tensordot(m, p, 1).numpy()))
# True
%timeit apply_momentum_scan(m, p, 1)
# 852 ms ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 1)
# 659 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 1)
# 741 ms ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Axis 2
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_matmul(m, p, 2).numpy()))
# True
print(np.allclose(apply_momentum_scan(m, p, 2).numpy(), apply_momentum_tensordot(m, p, 2).numpy()))
# True
%timeit apply_momentum_scan(m, p, 2)
# 1.06 s ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_matmul(m, p, 2)
# 924 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit apply_momentum_tensordot(m, p, 2)
# 483 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Ну, теперь это уже не так ясно. Сканирование все еще не супер быстрое, но матричные продукты иногда медленнее. Как вы можете себе представить, если вы go для еще больших тензоров, сложность матричных продуктов будет доминировать во временах.
Итак, если вы хотите самое быстрое решение и знаете, что ваши тензоры не станут огромными, используйте один реализации матричного продукта. Если у вас все хорошо с хорошей скоростью, но вы хотите убедиться, что у вас не хватает памяти (матричное решение также требует гораздо больше) и время предсказуемо, вы можете использовать решение для сканирования.
Примечание. Тесты выше были выполнены на CPU, результаты могут значительно отличаться на GPU.