Я пытался использовать tf.dynamic_partition()
для замены tf.gather()
, чтобы неявно преобразовывать разреженное представление в плотную матрицу.
Вот мой код,
# edge_source_states = tf.gather(params=node_states_per_layer[-1], indices=edge_sources)
partitions0 = tf.reduce_sum(tf.one_hot(edge_sources, tf.shape(node_states_per_layer[-1])[0], dtype='int32'),
0)
edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)
edge_source_states = edge_source_states[1]
Аннотацияоригинальное использование tf.gather()
.Не было ошибки при использовании tf.gather()
, единственная проблема в том, что он преобразовывает разреженное представление в плотную матрицу и, таким образом, потребляет много памяти.
Однако, когда я вместо этого использую метод tf.dynamic_partition()
, я получаю ошибку:
InvalidArgumentError (see above for traceback): partitions[21] = 2 is not in [0, 2)
И в соответствии с трассировкой, эта ошибка вызвана предложением:
edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)
Как новичок, я действительно не могу понять это.
Мои проблемы:
1) Я думал, что мои новые коды, использующие tf.dynamic_partition()
, функционально эквивалентны исходному коду, использующему tf.gather()
.Так почему же возникает ошибка?
2) Является ли tf.dynamic_partition()
решением для избежания неявного преобразования из разреженного в плотное, как tf.gather()
?Есть ли другие решения?Мне действительно нужно строго контролировать потребление памяти.