Сделать это можно следующим образом. Это немного сложно, но работает отлично.
Некоторые функции взяты из https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py. Я рекомендую выполнять импорт так же, как в приведённом выше коде. Основной код выглядит следующим образом:
def _dynamic_average_loop(inputs,
                          initial_state,
                          parallel_iterations,
                          swap_memory,
                          sequence_length=None,
                          dtype=None):
    """Accumulate a running element-wise sum of `inputs` along axis 0.

    The loop walks axis 0 of `inputs` (treated as time) with a `while_loop`,
    adding each time slice to the carried state. Axis 1 is treated as the
    batch axis. Despite the "average" in the name, the per-step operation is
    a plain addition (see `tf_average` below).

    Args:
      inputs: A tensor (or nest of tensors) of rank at least 3. All
        dimensions from axis 2 onwards must be statically known.
      initial_state: Starting value of the accumulator; each time slice is
        added to it.
      parallel_iterations: Python int forwarded to `while_loop`.
      swap_memory: Forwarded to `while_loop`.
      sequence_length: Optional vector with one entry per batch element.
        For steps where `time >= sequence_length[b]`, entry `b` keeps its
        previous state and its output for that step is zeros. When `None`,
        every step is accumulated for every batch entry.
      dtype: Optional dtype handed to `_infer_state_dtype` together with the
        state to choose the dtype of the output buffers.

    Returns:
      A pair `(final_outputs, final_state)`: the per-step outputs stacked
      along axis 0, and the accumulator after the last executed step.

    Raises:
      AssertionError: If `parallel_iterations` is not an int.
      ValueError: If the depth of an input is not statically known, or if
        the static time/batch dimensions differ between flattened inputs.

    NOTE(review): the non-graph (eager) branch below stores raw tensors and
    Python lists, while `_time_step` and the final stacking call
    `.read` / `.write` / `.stack` on them -- eager execution looks
    unsupported as written; confirm before relying on it.
    """
    state = initial_state
    assert isinstance(parallel_iterations, int), "parallel_iterations must be int"

    flat_input = nest.flatten(inputs)
    embedding_dimension = tf.shape(inputs)[2]
    flat_output_size = [embedding_dimension]

    # Construct an initial output
    input_shape = array_ops.shape(flat_input[0])
    time_steps = input_shape[0]
    batch_size = _best_effort_input_batch_size(flat_input)

    inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
                             for input_ in flat_input)

    const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]

    for shape in inputs_got_shape:
        if not shape[2:].is_fully_defined():
            raise ValueError(
                "Input size (depth of inputs) must be accessible via shape inference,"
                " but saw value None.")
        got_time_steps = shape[0].value
        got_batch_size = shape[1].value
        if const_time_steps != got_time_steps:
            raise ValueError(
                "Time steps is not the same for all the elements in the input in a "
                "batch.")
        if const_batch_size != got_batch_size:
            raise ValueError(
                "Batch_size is not the same for all the elements in the input.")

    # Prepare dynamic conditional copying of state & output
    def _create_zero_arrays(size):
        size = _concat(batch_size, size)
        return array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))

    flat_zero_output = tuple(_create_zero_arrays(output)
                             for output in flat_output_size)
    zero_output = nest.pack_sequence_as(structure=embedding_dimension,
                                        flat_sequence=flat_zero_output)

    # Only the maximum length is needed: it bounds the number of loop steps.
    if sequence_length is not None:
        max_sequence_length = math_ops.reduce_max(sequence_length)
    else:
        max_sequence_length = time_steps

    time = array_ops.constant(0, dtype=dtypes.int32, name="time")

    with ops.name_scope("dynamic_rnn") as scope:
        base_name = scope

    def _create_ta(name, element_shape, dtype):
        return tensor_array_ops.TensorArray(dtype=dtype,
                                            size=time_steps,
                                            element_shape=element_shape,
                                            tensor_array_name=base_name + name)

    in_graph_mode = not context.executing_eagerly()
    if in_graph_mode:
        output_ta = tuple(
            _create_ta(
                "output_%d" % i,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(
                                   _maybe_tensor_shape_from_tensor(out_size))),
                dtype=_infer_state_dtype(dtype, state))
            for i, out_size in enumerate(flat_output_size))
        input_ta = tuple(
            _create_ta(
                "input_%d" % i,
                element_shape=flat_input_i.shape[1:],
                dtype=flat_input_i.dtype)
            for i, flat_input_i in enumerate(flat_input))
        input_ta = tuple(ta.unstack(input_)
                         for ta, input_ in zip(input_ta, flat_input))
    else:
        output_ta = tuple([0 for _ in range(time_steps.numpy())]
                          for i in range(len(flat_output_size)))
        input_ta = flat_input

    def tf_average(A, B):
        # The per-step combination is a plain sum of the slice and the state.
        return A+B

    def _time_step(time, output_ta_t, state):
        """Run one loop iteration: read slice `time`, add it, write output."""
        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])
        input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
        flat_zero_output = nest.flatten(zero_output)

        the_average = tf_average(input_t, state)

        if sequence_length is None:
            # No per-entry lengths were given, so no batch entry is ever
            # "finished": the new sum is both the step output and the state.
            # (Comparing `time` against None would fail.)
            the_average_updated = the_average
            the_average_last_state = the_average
        else:
            # Vector describing which batch entries are finished.
            copy_cond = time >= sequence_length

            def _copy_one_through(output, new_output):
                # Propagate the old value for finished entries, else the new.
                with ops.colocate_with(new_output):
                    return array_ops.where(copy_cond, output, new_output)

            # Finished entries emit zeros and keep their previous state.
            the_average_updated = _copy_one_through(zero_output, the_average)
            the_average_last_state = _copy_one_through(state, the_average)

        for output, flat_output in zip([the_average_updated], flat_zero_output):
            output.set_shape(flat_output.get_shape())
        final_output = nest.pack_sequence_as(structure=zero_output,
                                             flat_sequence=[the_average_updated])
        output_ta_t = tuple(ta.write(time, out)
                            for ta, out in zip(output_ta_t, [final_output]))
        return (time + 1, output_ta_t, the_average_last_state)

    if in_graph_mode:
        # Make sure that we run at least 1 step, if necessary, to ensure
        # the TensorArrays pick up the dynamic shape.
        loop_bound = math_ops.minimum(
            time_steps, math_ops.maximum(1, max_sequence_length))
    else:
        # Using max_sequence_length isn't currently supported in the Eager branch.
        loop_bound = time_steps

    _, output_final_ta, final_state = control_flow_ops.while_loop(
        cond=lambda time, *_: time < loop_bound,
        body=_time_step,
        loop_vars=(time, output_ta, state),
        parallel_iterations=parallel_iterations,
        maximum_iterations=time_steps,
        swap_memory=swap_memory)

    final_outputs = tuple(ta.stack() for ta in output_final_ta)
    # Restore some shape information
    for output, output_size in zip(final_outputs, flat_output_size):
        shape = _concat(
            [const_time_steps, const_batch_size], output_size, static=True)
        output.set_shape(shape)
    final_outputs = nest.pack_sequence_as(structure=embedding_dimension,
                                          flat_sequence=final_outputs)
    return final_outputs, final_state
def dynamic_average(inputs, sequence_length=None, initial_state=None,
                    dtype=None, parallel_iterations=None, swap_memory=False,
                    time_major=False, scope=None):
    """Sum `inputs` over time, respecting per-example sequence lengths.

    Prepares the inputs (transposing batch-major data to time-major),
    validates `sequence_length`, builds the starting accumulator and hands
    everything to `_dynamic_average_loop`.

    Args:
      inputs: Input tensor (or nest of tensors). Batch-major
        `[batch, time, depth]` by default, `[time, batch, depth]` when
        `time_major` is true.
        NOTE(review): `tf.shape(inputs)[2]` is taken on the un-flattened
        value, so a single 3-D tensor appears to be the expected case --
        confirm before passing nested structures.
      sequence_length: Optional vector of length `batch_size`; it is cast
        to int32. Steps at or beyond an entry's length do not contribute
        to that entry's sum.
      initial_state: Optional starting accumulator. When omitted, a zero
        tensor of shape `(batch_size, depth)` is used.
      dtype: Optional dtype forwarded to the loop; also used for the zero
        starting accumulator when `initial_state` is omitted (float32 when
        `dtype` is `None`).
      parallel_iterations: Forwarded to the loop; defaults to 32.
      swap_memory: Forwarded to the loop.
      time_major: Whether `inputs` is already laid out time-first.
      scope: Optional variable scope name; defaults to `"rnn"`.

    Returns:
      A pair `(outputs, final_state)` as produced by
      `_dynamic_average_loop`; `outputs` is transposed back to batch-major
      unless `time_major` is true.

    Raises:
      ValueError: If `sequence_length` has a known rank other than 1.
    """
    with vs.variable_scope(scope or "rnn") as varscope:
        # Create a new scope in which the caching device is either
        # determined by the parent scope, or is set to place the cached
        # Variable using the same placement as for the rest of the RNN.
        if _should_cache():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # By default, time_major==False and inputs are batch-major: shaped
        #   [batch, time, depth]
        # For internal calculations, we transpose to [time, batch, depth]
        flat_input = nest.flatten(inputs)
        # Axis 2 is the depth in both the batch-major and time-major layout.
        embedding_dimension = tf.shape(inputs)[2]

        if not time_major:
            # (B,T,D) => (T,B,D)
            flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
            flat_input = tuple(_transpose_batch_time(input_)
                               for input_ in flat_input)

        parallel_iterations = parallel_iterations or 32
        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            if sequence_length.get_shape().ndims not in (None, 1):
                raise ValueError(
                    "sequence_length must be a vector of length batch_size, "
                    "but saw shape: %s" % sequence_length.get_shape())
            sequence_length = array_ops.identity(  # Just to find it in the graph.
                sequence_length, name="sequence_length")

        batch_size = _best_effort_input_batch_size(flat_input)

        # Honour a caller-supplied starting accumulator; otherwise start at
        # zero, in the dtype the loop will use for its output buffers.
        if initial_state is not None:
            state = initial_state
        else:
            state = tf.zeros(shape=(batch_size, embedding_dimension),
                             dtype=dtype if dtype is not None else tf.float32)

        def _assert_has_shape(x, shape):
            x_shape = array_ops.shape(x)
            packed_shape = array_ops.stack(shape)
            return control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
                ["Expected shape for Tensor %s is " % x.name,
                 packed_shape, " but saw shape: ", x_shape])

        if not context.executing_eagerly() and sequence_length is not None:
            # Perform some shape validation
            with ops.control_dependencies(
                    [_assert_has_shape(sequence_length, [batch_size])]):
                sequence_length = array_ops.identity(
                    sequence_length, name="CheckSeqLen")

        inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

        (outputs, final_state) = _dynamic_average_loop(
            inputs,
            state,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            sequence_length=sequence_length,
            dtype=dtype)

        if not time_major:
            # (T,B,D) => (B,T,D)
            outputs = nest.map_structure(_transpose_batch_time, outputs)

        return outputs, final_state
Это основной код. Чтобы убедиться, что он правильно вычисляет сумму по трёхмерному тензору с последовательностями переменной длины (как в RNN), можно выполнить следующую проверку:
# Sanity check: compare the final state produced by `dynamic_average` with a
# NumPy sum over the first `length` time steps of every batch entry.
tf.reset_default_graph()

# 30 sequences, up to 50 time steps each, 111 features per step.
the_inputs = np.random.uniform(-1,1,(30,50,111)).astype(np.float32)
the_length = np.random.randint(50, size=30)

the_input_tensor = tf.convert_to_tensor(the_inputs)
the_length_tensor = tf.convert_to_tensor(the_length)
outputs, final_state = dynamic_average(inputs=the_input_tensor,
                                       sequence_length=the_length_tensor)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
outputs_result, final_state_result = sess.run((outputs, final_state))

print("Testing")
for index, length in enumerate(the_length):
    expected = the_inputs[index, :length].sum(axis=0)
    # float32 sums computed in a different order can differ in the last
    # bits, so compare with a tolerance instead of exact equality.
    print(np.isclose(expected, final_state_result[index], atol=1e-5))
    print('------------------------------------------------------------------')