Вот вложенное while_loop
решение, которое записывает в один TensorArray
:
import tensorflow as tf
def make_inner_loop_body(total_size, anchor):
def _inner_loop_body(j, ta):
return j + 1, ta.write(total_size + j, anchor[j])
return _inner_loop_body
def loop_body(i, total_size, ta):
n_rows = tf.random_uniform((), minval=0, maxval=10, dtype=tf.int32)
n_rows = tf.Print(n_rows, [n_rows])
anchor = tf.random_normal((n_rows, 4), dtype=tf.float64)
_, ta = tf.while_loop(lambda j, ta: j < n_rows,
make_inner_loop_body(total_size, anchor),
(tf.zeros([], dtype=tf.int32), ta))
return i+1, total_size + n_rows, ta
_, _, anchors= tf.while_loop(lambda i, total_size, ta: i < 4,
loop_body,
(tf.zeros([], dtype=tf.int32),
tf.zeros([], dtype=tf.int32),
tf.TensorArray(tf.float64, size=0,
dynamic_size=True)))
anchors = anchors.stack()
anchors = tf.reshape(anchors, shape=(-1, 4))
anchors = tf.identity(anchors, name="anchors")
with tf.Session() as sess:
result = sess.run(anchors)
print("Final shape", result.shape)
print(result)
Это печатает что-то вроде:
[5]
[5]
[7]
[7]
Final shape (24, 4)
Я предполагаю, что есть какая-то причина, по которой random_normal
нужно обработать в while_loop
. В противном случае было бы гораздо проще написать:
import tensorflow as tf
n_rows = tf.random_uniform((4,), minval=0, maxval=10, dtype=tf.int32)
anchors = tf.random_normal((tf.reduce_sum(n_rows), 4), dtype=tf.float64)
with tf.Session() as sess:
result = sess.run(anchors)
print("Final shape", result.shape)
print(result)