tf.is_nan не соответствует tf.check_numerics при проверке NaN - PullRequest
0 голосов
/ 16 января 2020

Я использую tenorflow 1.14 и отлаживаю проблему с NaN. Но tf.is_nan дает мне результат, отличный от tf.check_numerics

Вот мой фрагмент кода:

# q_emb_orig is a tensor with shape[batch, word, embedding]
qemborignan = tf.is_nan(q_emb_orig) 
qemboriginf = tf.is_inf(q_emb_orig)
qemnni = tf.reduce_sum(tf.add(tf.cast(qemborignan,tf.int16),tf.cast(qemboriginf,tf.int16)),axis=1)
qemnni = tf.reduce_sum(qemnni,axis=1)
qemnnb = tf.cast(qemnni,tf.bool)
qemnnw = tf.where(qemnnb)
qemnns = tf.squeeze(qemnnw)
qemnnr = tf.gather(q_emb_orig,qemnns,axis=0)
def f1():
    printfunc = tf.print(qemnnr,summarize=-1)
    with tf.control_dependencies([printfunc]):
        return tf.constant(1)
def f2():
    return tf.constant(1)
shouldprint = tf.cond(tf.not_equal(tf.size(qemnnr), 0) ,true_fn = f1 ,false_fn=f2)
with tf.control_dependencies([shouldprint,tf.print(qemnni,summarize=-1)]):
    q_emb_orig = tf.check_numerics(q_emb_orig,'q emb is nan')

Код logi c таков, что если q_emb_orig содержит какие-либо nan или inf, он будет распечатан. Но в результате возникает только ошибка InvalidArgumentError, но ничего не печатается.

Я что-то пропустил или график работает в другой последовательности?

Обновление: Предполагается, что nan появился во время обратного распространения, поэтому check_numerics выполняется до is_nan , но я не могу подтвердить.

...