Ниже вы можете найти простую реализацию словаря векторов в Tensorflow. Это также пример использования tf.lookup.experimental.DenseHashTable
и tf.TensorArray
.
Как сказано, векторы не могут храниться в tf.lookup.experimental.DenseHashTable
, и поэтому tf.TensorArray
используется для хранения фактических векторов.
Конечно, это простой пример, и он не включает удаление записей в словаре - операцию, которая потребует некоторого управления свободными ячейками массива. Кроме того, вы должны прочитать на соответствующих страницах API tf.lookup.experimental.DenseHashTable
и tf.TensorArray
, как настроить их для своих нужд.
import tensorflow as tf
class DictionaryOfVectors:
def __init__(self, dtype):
empty_key = tf.constant('')
deleted_key = tf.constant('deleted')
self.ht = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string,
value_dtype=tf.int32,
default_value=-1,
empty_key=empty_key,
deleted_key=deleted_key)
self.ta = tf.TensorArray(dtype, size=0, dynamic_size=True, clear_after_read=False)
self.inserts_counter = 0
@tf.function
def insertOrAssign(self, key, vec):
# Insert the vector to the TensorArray. The write() method returns a new
# TensorArray object with flow that ensures the write occurs. It should be
# used for subsequent operations.
with tf.init_scope():
self.ta = self.ta.write(self.inserts_counter, vec)
# Insert the same counter value to the hash table
self.ht.insert_or_assign(key, self.inserts_counter)
self.inserts_counter += 1
@tf.function
def lookup(self, key):
with tf.init_scope():
index = self.ht.lookup(key)
return self.ta.read(index)
dictionary_of_vectors = DictionaryOfVectors(dtype=tf.float32)
dictionary_of_vectors.insertOrAssign('first', [1,2,3,4,5])
print(dictionary_of_vectors.lookup('first'))
Пример немного сложнее, поскольку методы вставки и поиска украшены @tf.function
. Поскольку методы изменяют переменные, определенные вне них, используется tf.init_scope()
. Вы можете спросить, что изменилось в методе lookup()
, поскольку на самом деле он считывает только таблицу ha sh и массив. Причина в том, что в режиме графика индекс, который возвращается из вызова lookup()
, является Tensor, а в реализации TensorArray есть строка, содержащая if index < 0:
, которая не выполняется с:
OperatorNotAllowedInGraphError : использование tf.Tensor
в качестве Python bool
недопустимо.
Когда мы используем tf.init_scope()
, как объясняется в документации по его API, «код внутри блока init_scope выполняется с активным исполнением даже при отслеживании tf.function
». Так что в этом случае этот индекс не тензор, а скаляр.