Тип теста float32 в режиме без ожидания TensorFlow - PullRequest
0 голосов
/ 01 февраля 2019

Я хочу проверить тип тензора в неагрессивном режиме.В режиме Eager этот код работает:

tf.enable_eager_execution()
t = tf.random.normal([1, 1])
num_type = t.dtype
print(num_type == tf.float32) # prints `True`

В режиме не-Eager тот же код не работает, и я нашел единственный способ проверить это уродливый str(num_type) == "float32":

sess = tf.Session()

t = sess.run(tf.random.normal([1, 1]))

num_type = t.dtype
print(num_type) # prints `float32`
print(str(num_type) == "float32") # prints `True`
print(num_type == float32) # returns `NameError: name 'float32' is not defined`
print(num_type == tf.float32) # returns `TypeError: data type not understood`

и если я попытаюсь захватить тип тензора в сеансе:

t = tf.random.normal([1, 1])
t_type = t.dtype

num_type = sess.run(t_type)

, тогда я получу:

TypeError: Fetch argument tf.float32 has invalid type <class 'tensorflow.python.framework.dtypes.DType'>, must be a string or Tensor. (Can not convert a DType into a Tensor or Operation.)

Как проверить тип float32 вне нетерпеливый режим?

1 Ответ

0 голосов
/ 01 февраля 2019

После оценки тензора в сеансе у вас есть объект-тензор, а не объект tf.Tensor (который вы получили и используете непосредственно в активном режиме).

Таким образом, ваш тест должен быть:

t.dtype == np.float32
...