Если у вас есть хотя бы TensorFlow 1.8.0, вам, вероятно, лучше всего использовать tf.contrib.integrate.odeint_fixed()
, как этот код (проверено):
from __future__ import print_function
import tensorflow as tf
assert tf.VERSION >= "1.8.0", "This code only works with TensorFlow 1.8.0 or later."
def f( y, a ):
return a * a
x = tf.constant( [ 0.0, 1.0, 2, 3, 4 ], dtype = tf.float32 )
i = tf.contrib.integrate.odeint_fixed( f, 0.0, x, method = "rk4" )
with tf.Session() as sess:
res = sess.run( i )
print( res )
выведет:
[0. 0.33333334 2.6666667 9. 21.333334]
, должным образом интегрируя x 2 через интервалы [0, 0] , [0, 1] , [0, 2] , [0, 3] и [0, 4] согласно x = [ 0, 1, 2, 3, 4 ]
выше.(Примитивная функция x 2 равна ⅓ x 3 , поэтому дляпример 4 3 / 3 = 64/3 = 21 ⅓ .)
В противном случае, для более ранних версий TensorFlow, вот как исправить ваш код.
Таким образом, основная проблема заключается в том, что вы должны использовать tf.py_func()
для отображения функции Python (scipy.integrate.quad()
в этом случае) на тензор.tf.map_fn()
отобразит другие операции TensorFlow и пропустит и ожидает тензоры в качестве операндов.Следовательно, x[ 0 ]
не будет никогда простым плавающим числом, оно будет скалярным тензором и scipy.integrate.quad()
не будет знать, что с этим делать.
Вы можетене полностью избавиться от tf.map_fn()
, если только вы не хотите вручную циклически перебирать массивы.
Кроме того, scipy.integrate.quad()
возвращает удвоение (float64)тогда как ваши тензоры - float32.
Я значительно упростил ваш код, потому что у меня нет доступа к остальному, и он выглядит слишком сложным по сравнению с ядром этого вопроса.Будет также выведен следующий код (проверено):
from __future__ import print_function
import tensorflow as tf
from scipy import integrate
def f( a ):
return a * a
def integrated( f, x ):
return tf.map_fn( lambda y: tf.py_func(
lambda z: integrate.quad( f, 0.0, z )[ 0 ], [ y ], tf.float64 ),
x )
x = tf.constant( [ 1.0, 2, 3, 4 ], dtype = tf.float64 )
i = integrated( f, x )
with tf.Session() as sess:
res = sess.run( i )
print( res )
:
[0.33333333 2.66666667 9. 21.33333333]