Как решить InvalidArgumentError для dynamic_partition в тензорном потоке? - PullRequest
0 голосов
/ 07 октября 2018

Я пытался использовать tf.dynamic_partition() для замены tf.gather(), чтобы неявно преобразовывать разреженное представление в плотную матрицу.

Вот мой код,

# edge_source_states = tf.gather(params=node_states_per_layer[-1], indices=edge_sources)
partitions0 = tf.reduce_sum(tf.one_hot(edge_sources, tf.shape(node_states_per_layer[-1])[0], dtype='int32'),
                                                        0)
edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)
edge_source_states = edge_source_states[1]

Аннотацияоригинальное использование tf.gather().Не было ошибки при использовании tf.gather(), единственная проблема в том, что он преобразовывает разреженное представление в плотную матрицу и, таким образом, потребляет много памяти.
Однако, когда я вместо этого использую метод tf.dynamic_partition(), я получаю ошибку:

InvalidArgumentError (see above for traceback): partitions[21] = 2 is not in [0, 2)

И в соответствии с трассировкой, эта ошибка вызвана предложением:

edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)

Как новичок, я действительно не могу понять это.

Мои проблемы:

1) Я думал, что мои новые коды, использующие tf.dynamic_partition(), функционально эквивалентны исходному коду, использующему tf.gather().Так почему же возникает ошибка?

2) Является ли tf.dynamic_partition() решением для избежания неявного преобразования из разреженного в плотное, как tf.gather()?Есть ли другие решения?Мне действительно нужно строго контролировать потребление памяти.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...