Проблема в том, что полиномиальный метод TensorFlow sample()
фактически использует вызовы метода _sample_n()
.Этот метод определен здесь .Как мы видим из кода для выборки из многочлена, код создает матрицу one_hot для каждой строки , а затем сокращает матрицу до вектора путем суммирования по строкам:
math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)
Это неэффективно, потому что он использует дополнительную память.Чтобы избежать этого, я использовал функцию tf.scatter_nd
.Вот полностью работоспособный пример:
import tensorflow as tf
import numpy as np
import tensorflow.contrib.distributions as ds
import time
tf.reset_default_graph()
nb_distribution = 100 # number of probabilities distribution
u = np.random.randint(2000, 3500, size=nb_distribution) # define number of counts (vector of size 100 with int in 2000, 3500)
# probsn is a matrix of probability:
# each row of probsn contains a vector of size 30 that sums to 1
probsn = np.random.uniform(size=(nb_distribution, 30))
probsn /= np.sum(probsn, axis=1)[:, None]
counts = tf.Variable(u, dtype=tf.float32)
probs = tf.Variable(tf.convert_to_tensor(probsn.astype(np.float32)))
# sample from the multinomial
dist = ds.Multinomial(total_count=counts, probs=probs)
out = dist.sample()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
res = sess.run(out) # if remove this line the code is slower...
start = time.time()
res = sess.run(out)
print(time.time() - start)
print(np.all(u == np.sum(res, axis=1)))
Этот код занял 0,05 секунды для вычисления
def vmultinomial_sampling(counts, pvals, seed=None):
k = tf.shape(pvals)[1]
logits = tf.expand_dims(tf.log(pvals), 1)
def sample_single(args):
logits_, n_draw_ = args[0], args[1]
x = tf.multinomial(logits_, n_draw_, seed)
indices = tf.cast(tf.reshape(x, [-1,1]), tf.int32)
updates = tf.ones(n_draw_) # tf.shape(indices)[0]
return tf.scatter_nd(indices, updates, [k])
x = tf.map_fn(sample_single, [logits, counts], dtype=tf.float32)
return x
xx = vmultinomial_sampling(u, probsn)
# check = tf.expand_dims(counts, 1) * probs
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
res = sess.run(xx) # if remove this line the code is slower...
start_t = time.time()
res = sess.run(xx)
print(time.time() -start_t)
#print(np.sum(res, axis=1))
print(np.all(u == np.sum(res, axis=1)))
Этот код занял 0,016 секунды
Недостаток в том, что мой код нена самом деле не распараллеливать вычисления (хотя параметр parallel_iterations
по умолчанию равен 10 в map_fn
, установка его в 1 ничего не меняет ...)
Может быть, кто-то найдет что-то лучше, потому чтоон все еще очень медленный по сравнению с реализацией Theano (из-за того, что он не использует преимущества распараллеливания ... и, тем не менее, здесь распараллеливание имеет смысл, поскольку выборка одной строки не зависит от выборки другой ...)