Ошибка TensorFlow при реализации потери покрытия (coverage loss) с использованием tf.nn.raw_rnn при обучении модели sequence-to-sequence - PullRequest
0 голосов
/ 10 сентября 2018

Я пытаюсь реализовать потерю покрытия (coverage loss, аналогичную той, что описана в статье «Get To The Point: Summarization with Pointer-Generator Networks», ссылка - https://arxiv.org/pdf/1704.04368.pdf) в модели sequence-to-sequence с использованием функции tf.nn.raw_rnn, чтобы штрафовать повторение слов в выводе декодера. Я собираю значения покрытия внутри loop_fn_transition, а затем добавляю их в cross_entropy_loss. Но при этом я получаю следующую ошибку:

INFO:tensorflow:Cannot use 'raw_rnn/rnn/while/Sum_2' as input 
to 'ce_loss/Rank/packed' because 'raw_rnn/rnn/while/Sum_2' is in a while loop.

ce_loss/Rank/packed while context: None
raw_rnn/rnn/while/Sum_2 while context: raw_rnn/rnn/while/while_context
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-6ecca84ce9de> in <module>()
      3      logits=dec_logits, labels=tf.one_hot(dec_target, depth=vocab_size, 
      4                                          dtype=tf.float32))
----> 5     loss = tf.reduce_mean(stepwise_cross_entropy)+tf.reduce_sum(cov_loss_arr)
      6 tf.summary.scalar('loss', loss)
      7 merged = tf.summary.merge_all()

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    452                 'in a future version' if date is None else ('after %s' % date),
    453                 instructions)
--> 454       return func(*args, **kwargs)
    455     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    456                                        _add_deprecated_arg_notice_to_docstring(

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py in reduce_sum(input_tensor, axis, keepdims, name, reduction_indices, keep_dims)
   1303                                    input_tensor,
   1304                                    _ReductionDims(input_tensor, axis,
-> 1305                                                   reduction_indices),
   1306                                    keepdims,
   1307                                    name=name))

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py in _ReductionDims(x, axis, reduction_indices)
   1235 
   1236     # Otherwise, we rely on Range and Rank to do the right thing at run-time.
-> 1237     return range(0, array_ops.rank(x))
   1238 
   1239 

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py in rank(input, name)
    365   @end_compatibility
    366   """
--> 367   return rank_internal(input, name, optimize=True)
    368 
    369 

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py in rank_internal(input, name, optimize)
    385       return gen_array_ops.size(input.dense_shape, name=name)
    386     else:
--> 387       input_tensor = ops.convert_to_tensor(input)
    388       input_shape = input_tensor.get_shape()
    389       if optimize and input_shape.ndims is not None:

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, preferred_dtype)
    996       name=name,
    997       preferred_dtype=preferred_dtype,
--> 998       as_ref=False)
    999 
   1000 

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
   1092 
   1093     if ret is None:
-> 1094       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1095 
   1096     if ret is NotImplemented:

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py in _autopacking_conversion_function(v, dtype, name, as_ref)
    959   if dtype is not None and dtype != inferred_dtype:
    960     return NotImplemented
--> 961   return _autopacking_helper(v, inferred_dtype, name or "packed")
    962 
    963 

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py in _autopacking_helper(list_or_tuple, dtype, name)
    922           elems_as_tensors.append(
    923               constant_op.constant(elem, dtype=dtype, name=str(i)))
--> 924       return gen_array_ops.pack(elems_as_tensors, name=scope)
    925     else:
    926       return converted_elems

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py in pack(values, axis, name)
   4592     axis = _execute.make_int(axis, "axis")
   4593     _, _, _op = _op_def_lib._apply_op_helper(
-> 4594         "Pack", values=values, axis=axis, name=name)
   4595     _result = _op.outputs[:]
   4596     _inputs_flat = _op.inputs

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    785         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    786                          input_types=input_types, attrs=attr_protos,
--> 787                          op_def=op_def)
    788       return output_structure, op_def.is_stateful, op
    789 

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    452                 'in a future version' if date is None else ('after %s' % date),
    453                 instructions)
--> 454       return func(*args, **kwargs)
    455     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    456                                        _add_deprecated_arg_notice_to_docstring(

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3153           input_types=input_types,
   3154           original_op=self._default_original_op,
-> 3155           op_def=op_def)
   3156       self._create_op_helper(ret, compute_device=compute_device)
   3157     return ret

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1744 
   1745     if not c_op:
-> 1746       self._control_flow_post_processing()
   1747 
   1748   def _control_flow_post_processing(self):

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in _control_flow_post_processing(self)
   1753     """
   1754     for input_tensor in self.inputs:
-> 1755       control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
   1756     if self._control_flow_context is not None:
   1757       self._control_flow_context.AddOp(self)

~/Installs/tensorflow_nn/tf_venv/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_util.py in CheckInputFromValidContext(op, input_op)
    312         input_op.name, "".join(traceback.format_list(input_op.traceback)))
    313     logging.info(log_msg)
--> 314     raise ValueError(error_msg + " See info log for more details.")

ValueError: Cannot use 'raw_rnn/rnn/while/Sum_2' as input to 'ce_loss/Rank/packed' because 'raw_rnn/rnn/while/Sum_2' is in a while loop. See info log for more details.

Вот мой код:

# Graph-construction-time Python lists that accumulate per-step attention
# distributions and coverage losses.
# NOTE(review): these lists are mutated inside raw_rnn's loop_fn, which is
# traced only a fixed number of times while the internal tf.while_loop graph
# is built — they do NOT grow once per runtime decoding step.  Moreover, the
# tensors appended to them live inside the while-loop context, so consuming
# them outside the loop raises the "is in a while loop" ValueError shown
# above.  Per-step accumulation must go through raw_rnn's loop_state
# (e.g. a TensorArray) instead.
att_till_t = []
cov_loss_arr = []

def loop_fn_transition(time, previous_output, previous_state, 
                       previous_loop_state):
    """raw_rnn transition step: coverage-aware attention decoding.

    Computes an additive-attention distribution over the encoder outputs,
    augmented with a coverage vector ``ct`` (the sum of all previous
    attention distributions, as in See et al. 2017), records the step's
    coverage loss, and feeds the embedding of the greedily predicted token
    back in as the next decoder input.

    NOTE(review): ``att_till_t`` and ``cov_loss_arr`` are plain Python
    lists mutated at graph-construction time.  raw_rnn traces this function
    once to build the body of its internal tf.while_loop, so the lists do
    not accumulate one entry per runtime step, and the appended tensors
    belong to the while-loop context — reading them outside the loop raises
    the ValueError quoted above.  The per-step values should be carried in
    the ``loop_state`` return slot (e.g. a TensorArray), which this
    function currently sets to None.
    """
    global att_till_t, cov_loss_arr
    # True for batch elements whose target length has been reached.
    elems_finished = (time >= dec_len)
    # Coverage vector: sum of previously recorded attention distributions
    # (zero on the first traced call).
    if not att_till_t:
        ct = tf.zeros(encoder_hidden_units)
    else:
        ct = tf.reduce_sum(att_till_t, axis=0)
    # Additive attention energies; W_a, W_c, b_a, v and enc_linear_layer
    # are module-level graph variables defined elsewhere in the file.
    att_raw = tf.reduce_sum(
        tf.tanh((previous_output*W_a+ct*W_c+b_a)+enc_linear_layer)*v, 
        axis=2, keepdims=True)
    # NOTE(review): this mask keeps attention only where time >= dec_len
    # (i.e. for FINISHED elements).  Masking the still-active elements with
    # tf.logical_not(elems_finished) looks like the intent — confirm.
    att = tf.nn.softmax(att_raw, axis=0)*tf.cast(elems_finished, tf.float32)
    att = att/tf.reduce_sum(att, axis=0, keepdims=True)
    # Per-step coverage penalty: sum of min(attention, coverage)
    # (See et al. 2017, the covloss term).
    cov_loss = tf.reduce_sum(tf.minimum(att, ct), [1])
    cov_loss_arr.append(cov_loss)
    att_till_t.append(att)
    # NOTE(review): tf.summary.scalar expects a scalar; tf.argmax(att) is
    # generally not scalar here — verify this summary actually works.
    tf.summary.scalar('att', tf.argmax(att))
    # Context vector: attention-weighted sum of encoder outputs.
    d2 = tf.reduce_sum(att*enc_ops, axis=0)
    output_logits = tf.add(tf.matmul(d2, W), b)
    pred = tf.argmax(output_logits, axis=1)
    next_inp = tf.nn.embedding_lookup(embedding, pred)
    # Once every batch element is finished, feed the PAD embedding instead
    # of the predicted token's embedding.
    finished = tf.reduce_all(elems_finished)
    inp = tf.cond(finished, lambda:pad_step_emb, lambda:next_inp)
    state = previous_state
    output = previous_output
    loop_state = None
    return (elems_finished, inp, state, 
           output, loop_state)

def loop_fn(time, previous_output, previous_state, 
            previous_loop_state):
    """raw_rnn loop_fn: route the initial call vs. per-step transitions.

    raw_rnn invokes this once with every ``previous_*`` argument set to
    None (time == 0) and with real tensors on each subsequent step.
    """
    if previous_state is not None:
        return loop_fn_transition(time, previous_output,
                                  previous_state, previous_loop_state)
    # time == 0: raw_rnn passes None for all previous_* arguments.
    assert previous_output is None
    return loop_fn_init()

# Run the decoder: raw_rnn drives dec_cell via loop_fn; the per-step emits
# come back as a TensorArray that is stacked to [max_steps, batch, dim].
# NOTE(review): the discarded third return value is the FINAL loop_state —
# the correct place to carry per-step coverage losses out of the while loop.
with tf.variable_scope("raw_rnn", reuse=tf.AUTO_REUSE):
    dec_ops_ta, dec_fs, _ = tf.nn.raw_rnn(dec_cell, loop_fn)
dec_ops = dec_ops_ta.stack()

with tf.variable_scope("dec_ops"):
    # Flatten [max_steps, batch, dim] -> [max_steps*batch, dim] so one
    # matmul projects every timestep to vocabulary logits at once.
    dec_max_steps, dec_batch_size, dec_dim = tf.unstack(
                        tf.shape(dec_ops))
    dec_op_flat = tf.reshape(dec_ops, (-1, dec_dim))
    dec_logits_flat = tf.add(tf.matmul(dec_op_flat, W), b)
    # Restore the [max_steps, batch, vocab_size] shape for the loss.
    dec_logits = tf.reshape(dec_logits_flat, (dec_max_steps, 
                                              dec_batch_size, vocab_size))
    # Greedy token prediction at each (step, batch) position.
    dec_pred = tf.argmax(dec_logits, 2)

with tf.variable_scope("ce_loss"):
    # Per-position cross-entropy against one-hot targets.
    stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
     logits=dec_logits, labels=tf.one_hot(dec_target, depth=vocab_size, 
                                         dtype=tf.float32))
    # NOTE(review): tf.reduce_sum(cov_loss_arr) is the line that raises the
    # ValueError quoted above — cov_loss_arr holds tensors created inside
    # raw_rnn's tf.while_loop, and while-loop tensors cannot be consumed
    # from outside the loop.  Fix: accumulate the coverage loss inside
    # loop_fn via raw_rnn's loop_state (a TensorArray or a running scalar
    # sum), read it from the final loop_state raw_rnn returns, and add that
    # tensor here instead of the Python list.
    loss = tf.reduce_mean(stepwise_cross_entropy)+tf.reduce_sum(cov_loss_arr)
tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()
with tf.variable_scope("optimizer"):
    train_op = tf.train.AdamOptimizer().minimize(loss)

Я поискал в Google, но не смог найти описания похожей проблемы. Пожалуйста, подскажите, что мне следует изменить в моём подходе. Заранее спасибо.

...