Как сохранить векторы в словаре в тензорном потоке? - PullRequest
0 голосов
/ 28 мая 2020

Кажется, что tf.lookup.experimental.DenseHashTable не может содержать векторы, и я не смог найти примеров, как его использовать.

1 Ответ

0 голосов
/ 28 мая 2020

Ниже вы можете найти простую реализацию словаря векторов в Tensorflow. Это также пример использования tf.lookup.experimental.DenseHashTable и tf.TensorArray.

Как сказано, векторы не могут храниться в tf.lookup.experimental.DenseHashTable, и поэтому tf.TensorArray используется для хранения фактических векторов.

Конечно, это простой пример, и он не включает удаление записей в словаре - операцию, которая потребует некоторого управления свободными ячейками массива. Кроме того, вы должны прочитать на соответствующих страницах API tf.lookup.experimental.DenseHashTable и tf.TensorArray, как настроить их для своих нужд.

import tensorflow as tf


class DictionaryOfVectors:

  def __init__(self, dtype):
    empty_key = tf.constant('')
    deleted_key = tf.constant('deleted')

    self.ht = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int32,
                                                    default_value=-1,
                                                    empty_key=empty_key,
                                                    deleted_key=deleted_key)
    self.ta = tf.TensorArray(dtype, size=0, dynamic_size=True, clear_after_read=False)
    self.inserts_counter = 0

  @tf.function
  def insertOrAssign(self, key, vec):
    # Insert the vector to the TensorArray. The write() method returns a new
    # TensorArray object with flow that ensures the write occurs. It should be 
    # used for subsequent operations.
    with tf.init_scope():
      self.ta = self.ta.write(self.inserts_counter, vec)

      # Insert the same counter value to the hash table
      self.ht.insert_or_assign(key, self.inserts_counter)
      self.inserts_counter += 1

  @tf.function
  def lookup(self, key):
    with tf.init_scope():
      index = self.ht.lookup(key)
      return self.ta.read(index)

dictionary_of_vectors = DictionaryOfVectors(dtype=tf.float32)
dictionary_of_vectors.insertOrAssign('first', [1,2,3,4,5])
print(dictionary_of_vectors.lookup('first'))

Пример немного сложнее, поскольку методы вставки и поиска украшены @tf.function. Поскольку методы изменяют переменные, определенные вне них, используется tf.init_scope(). Вы можете спросить, что изменилось в методе lookup(), поскольку на самом деле он считывает только таблицу ha sh и массив. Причина в том, что в режиме графика индекс, который возвращается из вызова lookup(), является Tensor, а в реализации TensorArray есть строка, содержащая if index < 0:, которая не выполняется с:

OperatorNotAllowedInGraphError : использование tf.Tensor в качестве Python bool недопустимо.

Когда мы используем tf.init_scope(), как объясняется в документации по его API, «код внутри блока init_scope выполняется с активным исполнением даже при отслеживании tf.function». Так что в этом случае этот индекс не тензор, а скаляр.

...