тензор потока Как я могу решить конфликт dtype между one_hot и sign - PullRequest
0 голосов
/ 23 октября 2018

Моя нейронная сеть имеет следующий вывод:

  • logits - это выходные данные из tanh узлов, поэтому значение является плавающим в пределах (-1,1)
  • action это знак logits
  • one_hot это версия one_hot action, с размером 3 представляет -1, 0 и + 1

ПроблемаМоя функция потерь связана со значениями one_hot, поэтому я строю выходную часть нейронной сети следующим образом:

logits = tf.contrib.layers.fully_connected(outputs, 1, activation_fn=tf.tanh)
action = tf.sign(logits)
one_hot = tf.one_hot(action+1, depth=3)

, и это дает мне TypeError

TypeError: Значение, переданное параметру index, имеет тип DataType float32, которого нет в списке допустимых значений: uint8, int32, int64

Затем я попытался изменить one_hot на:

one_hot = tf.one_hot(tf.cast(action, tf.int32)+1, depth=3)

И я получил еще одну ошибку без градиентов:

ValueError: Градиенты не указаны ни для одной переменной, проверьте график на наличие операций, которые не поддерживают градиенты, между переменными [...]

Есть ли обходные пути, которые я мог бы использовать, чтобы избежать обеих ошибок.Любая помощь приветствуется.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...