Странная проблема с использованием TensorFlow - PullRequest
0 голосов
/ 13 февраля 2019

Я сталкиваюсь со странным поведением при использовании модуля TF при попытке вывести число с плавающей точкой из пользовательской функции для переноса в TF.

import tensorflow as tf
import numpy as np

def my_func(x):
  outX = 1.0
  return outX

input = tf.placeholder(tf.float32,shape=(2,))
resultFun = tf.py_func(my_func, [input], tf.float32)

with tf.Session() as sess:
  inival=[np.array(0.9),np.array(1.9)]
  print(sess.run(resultFun, feed_dict={input: inival})) 

Проблема в строке

outX=1.0

, поскольку он генерирует следующее:

InvalidArgumentError (см. Выше для отслеживания): 0-е значение, возвращаемое pyfunc_0, является двойным, но ожидает float

Я пытался форсировать бросок с помощью np.float или float, но с теми же результатами.На самом деле мне также удалось преодолеть это с помощью трюка:

outX = x*0+1.0

, и это работает.Так что это явно зависит от какого-то неправильного выбора.Как решить проблему с кастингом без ярлыков?

1 Ответ

0 голосов
/ 13 февраля 2019

Обычно TensorFlow пытается подогнать собственные значения Python к запрашиваемому типу данных, но иногда это может привести к путанице.В этом случае, так как сам Python не позволяет вам указать, является ли значение 32 или 64-битным, вы можете просто обернуть его с помощью NumPy:

import tensorflow as tf
import numpy as np

def my_func(x):
  outX = np.array(1.0, dtype=np.float32)
  return outX

input = tf.placeholder(tf.float32,shape=(2,))
resultFun = tf.py_func(my_func, [input], tf.float32)

with tf.Session() as sess:
  inival=[np.array(0.9),np.array(1.9)]
  print(sess.run(resultFun, feed_dict={input: inival})) 
  # 0.1
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...