tensorarray реализует простую форму ввода rnn - PullRequest
0 голосов
/ 16 апреля 2020

Я реализую простой RSN с tensorarray. Я обнаружил, что форму входных данных нужно транспонировать в [sequence_length, feature_dim] при использовании tensorarray.unstack (). Когда я пытался с формой, должен быть [batch_size, sequence_length, feature_dim] в качестве входных данных rnn, ошибка 。。。

Поэтому мой вопрос заключается в том, почему форма tensorarray.unstack () должна быть [T, B, D ] но не [B, T, D]?

Вот мой код:


B = 3  # batch_size

D = 4  # num_units

T = 5  # time_steps

xs = tf.placeholder(shape=[T, B, D], dtype=tf.float32)

with tf.variable_scope('rnn', reuse=):
    gru_cell = tf.nn.rnn_cell.GRUCell(num_units=D)
    cell = tf.nn.rnn_cell.MultiRNNCell([gru_cell])

    outputs_tensor_array = tf.TensorArray(size=T, dtype=tf.float32,
                                          dynamic_size=True,
                                          clear_after_read=False)
    inputs_tensor_array = tf.TensorArray(size=T, dtype=tf.float32)
    inputs_tensor_array.unstack(xs)  # axis=0 default


    def cond(time, output, state):
        return time < T


    def body(time, outputs_tensor_array, state):
        x_t = inputs_tensor_array.read(time)  # t时刻输入
        new_output, new_state = cell.call(x_t, state)
        outputs_ta_t = outputs_tensor_array.write(time, new_output)
        return time + 1, outputs_ta_t, new_state


    time = 0
    state = cell.zero_state(B, dtype=tf.float32)
    time_final, outputs, state_final = tf.while_loop(cond, body, loop_vars=[time, outputs_tensor_array, state])
    final_output = outputs.stack()
x = np.random.randn(T, B, D)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_final_, state_final_ = sess.run([final_output, state_final], feed_dict={xs: x})

...