tf. while_loop с гибкими номерами строк на одну итерацию - PullRequest
0 голосов
/ 29 апреля 2018

Я пытаюсь заполнить 2d массив в tf.while_loop. Дело в том, что результат моих вычислений на каждой итерации возвращает переменное количество строк. Tensorflow, похоже, не позволяет этого.

См. Этот минимальный пример, воспроизводящий проблему:

indices = tf.constant([2, 5, 7, 9])

num_elems = tf.shape(indices)[0]
init_array = tf.TensorArray(tf.float64, size=num_elems)
initial_i = tf.constant(0, dtype='int32')

def loop_body(i, ta):
    # Here if I choose a random rows number, it fails.
    n_rows = tf.random_uniform((), minval=0, maxval=10, dtype=tf.int64)

    # It works with a fixed row number.
    # n_rows = 2

    anchor = tf.random_normal((n_rows, 4))
    ta = ta.write(i, tf.cast(anchor, tf.float64))
    return i+1, ta

_, anchors= tf.while_loop(lambda i, ta: i < num_elems, loop_body, [initial_i, init_array])
anchors = anchors.stack()
anchors = tf.reshape(anchors, shape=(-1, 4))
anchors = tf.identity(anchors, name="anchors")

with tf.Session() as sess:
    result = sess.run(anchors)
    print(result)

Возвращает:

[[ 0.07496446 -0.32444516 -0.47164568  1.10953283]
 [-0.78791034  1.87736523  0.99817699  0.45336106]
 [-0.65860498 -1.1703862  -0.05761402 -0.17642537]
 [ 0.49713874  1.01805222  0.60902107  0.85543454]
 [-1.38755643 -0.70669901  0.34549037 -0.85984546]
 [-1.32419562  0.71003789  0.34984082 -1.39001906]
 [ 2.26691341 -0.63561141  0.38636214  0.02521387]
 [-1.55348766  1.0176425   0.4889268  -0.12093868]]

Я также открыт для альтернативных решений, чтобы заполнить тензор в цикле переменным числом строк на каждой итерации.

1 Ответ

0 голосов
/ 08 мая 2018

Вот вложенное while_loop решение, которое записывает в один TensorArray:

import tensorflow as tf

def make_inner_loop_body(total_size, anchor):

  def _inner_loop_body(j, ta):
    return j + 1, ta.write(total_size + j, anchor[j])

  return _inner_loop_body

def loop_body(i, total_size, ta):
    n_rows = tf.random_uniform((), minval=0, maxval=10, dtype=tf.int32)
    n_rows = tf.Print(n_rows, [n_rows])
    anchor = tf.random_normal((n_rows, 4), dtype=tf.float64)
    _, ta = tf.while_loop(lambda j, ta: j < n_rows,
                          make_inner_loop_body(total_size, anchor),
                          (tf.zeros([], dtype=tf.int32), ta))
    return i+1, total_size + n_rows, ta

_, _, anchors= tf.while_loop(lambda i, total_size, ta: i < 4,
                             loop_body,
                             (tf.zeros([], dtype=tf.int32),
                              tf.zeros([], dtype=tf.int32),
                              tf.TensorArray(tf.float64, size=0,
                                             dynamic_size=True)))
anchors = anchors.stack()
anchors = tf.reshape(anchors, shape=(-1, 4))
anchors = tf.identity(anchors, name="anchors")

with tf.Session() as sess:
    result = sess.run(anchors)
    print("Final shape", result.shape)
    print(result)

Это печатает что-то вроде:

[5]
[5]
[7]
[7]
Final shape (24, 4)

Я предполагаю, что есть какая-то причина, по которой random_normal нужно обработать в while_loop. В противном случае было бы гораздо проще написать:

import tensorflow as tf

n_rows = tf.random_uniform((4,), minval=0, maxval=10, dtype=tf.int32)
anchors = tf.random_normal((tf.reduce_sum(n_rows), 4), dtype=tf.float64)

with tf.Session() as sess:
    result = sess.run(anchors)
    print("Final shape", result.shape)
    print(result)
...