Я использую tf.scatter_nd
для обновления комплексных значений по некоторому индексу.
Кажется, что эта функция как-то складывается из реальной и мнимой частей. Мой вопрос заключается в том, как заставить его работать с заполнителями. Вот минимальный рабочий пример, где переменные b
и e
должны иметь одинаковые значения.
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])
# works with constants
a=tf.constant(update)
b=tf.scatter_nd(idx,a,shp)
with tf.Session() as sess:
print sess.run(b) # correct output: 1.+2j
#Does not work with placeholders
d=tf.placeholder(tf.complex128)
e=tf.scatter_nd(idx,d,shp)
with tf.Session() as sess:
print sess.run(e,feed_dict={d:update}) # WRONG output: 3.+0j
Я использую версию графического процессора Anaconda python 2.7 + TensorFlow 1.7, установленную с помощью команды conda.
Edit:
Проблема возникает при запуске кода на графическом процессоре. Версия процессора работает правильно.
Вот обновленный код для воспроизведения проблемы в TensorFlow-GPU 1.8, установленной с использованием Anaconda Python 2.7.
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])
a=tf.placeholder(tf.complex128)
with tf.device("/cpu:0"):
b=tf.scatter_nd(idx,a,shp)
with tf.device("/gpu:0"):
c=tf.scatter_nd(idx,a,shp)
with tf.Session() as sess:
print 'Correct output on CPU', sess.run(b,feed_dict={a:update})
print 'Wrong output on GPU',sess.run(c,feed_dict={a:update})
Я видел этот поток и этот поток , но не смог найти способ его разрешения. Есть ли альтернатива tf.scatter_nd
, которая будет работать на GPU?