Я использую tenorflow 1.14 и отлаживаю проблему с NaN. Но tf.is_nan
дает мне результат, отличный от tf.check_numerics
。
Вот мой фрагмент кода:
# q_emb_orig is a tensor with shape[batch, word, embedding]
qemborignan = tf.is_nan(q_emb_orig)
qemboriginf = tf.is_inf(q_emb_orig)
qemnni = tf.reduce_sum(tf.add(tf.cast(qemborignan,tf.int16),tf.cast(qemboriginf,tf.int16)),axis=1)
qemnni = tf.reduce_sum(qemnni,axis=1)
qemnnb = tf.cast(qemnni,tf.bool)
qemnnw = tf.where(qemnnb)
qemnns = tf.squeeze(qemnnw)
qemnnr = tf.gather(q_emb_orig,qemnns,axis=0)
def f1():
printfunc = tf.print(qemnnr,summarize=-1)
with tf.control_dependencies([printfunc]):
return tf.constant(1)
def f2():
return tf.constant(1)
shouldprint = tf.cond(tf.not_equal(tf.size(qemnnr), 0) ,true_fn = f1 ,false_fn=f2)
with tf.control_dependencies([shouldprint,tf.print(qemnni,summarize=-1)]):
q_emb_orig = tf.check_numerics(q_emb_orig,'q emb is nan')
Код logi c таков, что если q_emb_orig
содержит какие-либо nan
или inf
, он будет распечатан. Но в результате возникает только ошибка InvalidArgumentError
, но ничего не печатается.
Я что-то пропустил или график работает в другой последовательности?
Обновление: Предполагается, что nan появился во время обратного распространения, поэтому check_numerics выполняется до is_nan , но я не могу подтвердить.