Что такое интуиция за методом Iterator.get_next? - PullRequest
1 голос
/ 30 октября 2019

Название метода get_next() немного вводит в заблуждение. Документация гласит:

Возвращает вложенную структуру tf.Tensor s, представляющую следующий элемент.

В графическом режиме обычно следует вызывать этот метод Once ииспользовать его результат в качестве входных данных для другого вычисления. Затем типичный цикл вызовет tf.Session.run в результате этого вычисления. Цикл завершится, когда операция Iterator.get_next() вызовет tf.errors.OutOfRangeError. Следующий скелет показывает, как использовать этот метод при построении обучающего цикла:

dataset = ...  # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)

with tf.compat.v1.Session() as sess:
  try:
    while True:
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    pass

Python также имеет функцию с именем next, которую необходимо вызывать каждый разнам нужен следующий элемент итератора. Однако, согласно документации get_next(), приведенной выше, get_next() следует вызывать только один раз, а его результат следует оценивать, вызывая метод run сеанса, так что это немного не интуитивно понятно, потому что я был использованк встроенной функции Python next. В этот сценарий , get_next() также вызывается только, и результат вызова оценивается на каждом этапе вычисления.

Что такое интуиция за get_next() и как онаотличается от next()? Я думаю, что следующий элемент набора данных (или итерируемый итератор) во втором примере, который я связал, извлекается каждый раз, когда результат первого вызова get_next() оценивается путем вызова метода run, но этонемного не интуитивно. Я не понимаю, почему нам не нужно вызывать get_next на каждом шаге вычисления (чтобы получить следующий элемент итерируемого кода), даже после прочтения заметки в документации

ПРИМЕЧАНИЕ. Допустимо вызывать Iterator.get_next() несколько раз, например, когда вы распределяете разные элементы на несколько устройств за один шаг. Однако распространенная ошибка возникает, когда пользователи вызывают Iterator.get_next() в каждой итерации своего цикла обучения. Iterator.get_next() добавляет операции в граф, а выполнение каждой операции распределяет ресурсы (включая потоки);как следствие, вызов его на каждой итерации цикла обучения приводит к замедлению и возможному истощению ресурсов. Чтобы избежать этого, мы регистрируем предупреждение, когда количество использований пересекает фиксированный порог подозрительности.

В общем, неясно, как работает Итератор.

1 Ответ

1 голос
/ 30 октября 2019

Идея состоит в том, что get_next добавляет некоторые операции в график, так что каждый раз, когда вы их оцениваете, вы получаете следующий элемент в наборе данных. На каждой итерации вам просто нужно запускать операции, которые get_next выполнял, вам не нужно создавать их снова и снова.

Возможно, хороший способ получить интуицию - попытаться написать итераторсебя. Рассмотрим что-то вроде следующего:

import tensorflow as tf
tf.compat.v1.disable_v2_behavior()

# Make an iterator, returns next element and initializer
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(0)
    # Check we are not out of bounds
    with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
        # Get next value
        next_val_1 = data[i]
    # Update index after the value is read
    with tf.control_dependencies([next_val_1]):
        i_updated = tf.compat.v1.assign_add(i, 1)
        with tf.control_dependencies([i_updated]):
            next_val_2 = tf.identity(next_val_1)
    return next_val_2, i.initializer

# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    # Example data
    data = tf.constant([1, 2, 3, 4])
    # Make operations that give you the next element
    next_val, iter_init = iterator_next(data)
    # Initialize iterator
    sess.run(iter_init)
    # Iterate until exception is raised
    while True:
        try:
            print(sess.run(next_val))
        # assert throws InvalidArgumentError
        except tf.errors.InvalidArgumentError: break

Вывод:

1
2
3
4

Здесь iterator_next дает вам нечто сравнимое с тем, что даст вам get_next в итераторе, плюс инициализатороперация. Каждый раз, когда вы запускаете next_val, вы получаете новый элемент из data, вам не нужно каждый раз вызывать функцию (как работает next в Python), вы вызываете ее один раз, а затем оцениваете результат несколько раз. раз.

РЕДАКТИРОВАТЬ: Вышеприведенную функцию iterator_next также можно упростить до следующего:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    # Start from -1
    i = tf.Variable(-1)
    # First increment i
    i_updated = tf.compat.v1.assign_add(i, 1)
    with tf.control_dependencies([i_updated]):
        # Check i is not out of bounds
        with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
            # Get next value
            next_val = data[i]
    return next_val, i.initializer

или даже проще:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(-1)
    i_updated = tf.compat.v1.assign_add(i, 1)
    # Using i_updated directly as a value is equivalent to using i with
    # a control dependency to i_updated
    with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
        next_val = data[i_updated]
    return next_val, i.initializer
...