Установить пересечение в Tensorflow - PullRequest
0 голосов
/ 17 ноября 2018

Я хочу проверить, содержится ли какое-либо из заданных значений в разреженном тензоре.Разреженный тензор называется labels и имеет только одно измерение, содержащее список идентификаторов.

В конце концов, это похоже на проблему пересечения простых множеств, поэтому я попробовал это.

sparse_ids = load_ids_as_sparse_tensor()
wanted_ids = tf.constant([34, 56, 12])
intersection = tf.sets.set_intersection(
    wanted_ids,
    tf.cast(sparse_ids.values, tf.int32)
)
contains_any_wanted_ids = tf.not_equal(tf.size(intersection), 0)

Однако я получаю эту ошибку:

ValueError: Shape must be at least rank 2 but is rank 1 for 'DenseToDenseSetOperation' (op: 'DenseToDenseSetOperation') with input shapes: [3], [?].

Есть идеи?

1 Ответ

0 голосов
/ 17 ноября 2018

Следующий код работает.Однако я не уверен, что вы хотите получить результат.

import tensorflow as tf
a = tf.constant([34, 56, 12])
b = tf.constant([56])
intersection = tf.sets.set_intersection(a[None,:],b[None,:])
sess=tf.Session()
sess.run(intersection)

Вывод:

SparseTensorValue (indices = array ([[0, 0]], dtype =int64), значения = массив ([56]), плотность_форм = массив ([1, 1], dtype = int64))

...