Я сталкиваюсь со странным поведением при использовании модуля TF при попытке вывести число с плавающей точкой из пользовательской функции для переноса в TF.
import tensorflow as tf
import numpy as np
def my_func(x):
outX = 1.0
return outX
input = tf.placeholder(tf.float32,shape=(2,))
resultFun = tf.py_func(my_func, [input], tf.float32)
with tf.Session() as sess:
inival=[np.array(0.9),np.array(1.9)]
print(sess.run(resultFun, feed_dict={input: inival}))
Проблема в строке
outX=1.0
, поскольку он генерирует следующее:
InvalidArgumentError (см. Выше для отслеживания): 0-е значение, возвращаемое pyfunc_0, является двойным, но ожидает float
Я пытался форсировать бросок с помощью np.float или float, но с теми же результатами.На самом деле мне также удалось преодолеть это с помощью трюка:
outX = x*0+1.0
, и это работает.Так что это явно зависит от какого-то неправильного выбора.Как решить проблему с кастингом без ярлыков?