Идея состоит в том, что get_next
добавляет некоторые операции в график, так что каждый раз, когда вы их оцениваете, вы получаете следующий элемент в наборе данных. На каждой итерации вам просто нужно запускать операции, которые get_next
выполнял, вам не нужно создавать их снова и снова.
Возможно, хороший способ получить интуицию - попытаться написать итераторсебя. Рассмотрим что-то вроде следующего:
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
# Make an iterator, returns next element and initializer
def iterator_next(data):
data = tf.convert_to_tensor(data)
i = tf.Variable(0)
# Check we are not out of bounds
with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
# Get next value
next_val_1 = data[i]
# Update index after the value is read
with tf.control_dependencies([next_val_1]):
i_updated = tf.compat.v1.assign_add(i, 1)
with tf.control_dependencies([i_updated]):
next_val_2 = tf.identity(next_val_1)
return next_val_2, i.initializer
# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
# Example data
data = tf.constant([1, 2, 3, 4])
# Make operations that give you the next element
next_val, iter_init = iterator_next(data)
# Initialize iterator
sess.run(iter_init)
# Iterate until exception is raised
while True:
try:
print(sess.run(next_val))
# assert throws InvalidArgumentError
except tf.errors.InvalidArgumentError: break
Вывод:
1
2
3
4
Здесь iterator_next
дает вам нечто сравнимое с тем, что даст вам get_next
в итераторе, плюс инициализатороперация. Каждый раз, когда вы запускаете next_val
, вы получаете новый элемент из data
, вам не нужно каждый раз вызывать функцию (как работает next
в Python), вы вызываете ее один раз, а затем оцениваете результат несколько раз. раз.
РЕДАКТИРОВАТЬ: Вышеприведенную функцию iterator_next
также можно упростить до следующего:
def iterator_next(data):
data = tf.convert_to_tensor(data)
# Start from -1
i = tf.Variable(-1)
# First increment i
i_updated = tf.compat.v1.assign_add(i, 1)
with tf.control_dependencies([i_updated]):
# Check i is not out of bounds
with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
# Get next value
next_val = data[i]
return next_val, i.initializer
или даже проще:
def iterator_next(data):
data = tf.convert_to_tensor(data)
i = tf.Variable(-1)
i_updated = tf.compat.v1.assign_add(i, 1)
# Using i_updated directly as a value is equivalent to using i with
# a control dependency to i_updated
with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
next_val = data[i_updated]
return next_val, i.initializer