Я хочу проверить тип тензора в неагрессивном режиме.В режиме Eager этот код работает:
tf.enable_eager_execution()
t = tf.random.normal([1, 1])
num_type = t.dtype
print(num_type == tf.float32) # prints `True`
В режиме не-Eager тот же код не работает, и я нашел единственный способ проверить это уродливый str(num_type) == "float32"
:
sess = tf.Session()
t = sess.run(tf.random.normal([1, 1]))
num_type = t.dtype
print(num_type) # prints `float32`
print(str(num_type) == "float32") # prints `True`
print(num_type == float32) # returns `NameError: name 'float32' is not defined`
print(num_type == tf.float32) # returns `TypeError: data type not understood`
и если я попытаюсь захватить тип тензора в сеансе:
t = tf.random.normal([1, 1])
t_type = t.dtype
num_type = sess.run(t_type)
, тогда я получу:
TypeError: Fetch argument tf.float32 has invalid type <class 'tensorflow.python.framework.dtypes.DType'>, must be a string or Tensor. (Can not convert a DType into a Tensor or Operation.)
Как проверить тип float32
вне нетерпеливый режим?