Как проверить объекты tf.data Dataset? - PullRequest
0 голосов
/ 10 мая 2018

При контрольной точке во время обучения (в случае сбоя и т. Д.) Я сохраняю график и параметры, но неясно, как сделать то же самое с новыми tf.data объектами, используемыми для ввода.

Есть ли простой способ также проверить их так, чтобы я мог продолжить текущую эпоху или восстановить состояние перемешивания (возможно, из начального числа?)

1 Ответ

0 голосов
/ 10 мая 2018

Функция tf.contrib.data.make_saveable_from_iterator() принимает объект tf.data.Iterator и возвращает вам «сохраняемый объект», который можно сохранить с помощью tf.train.Saver. Он сохраняет все состояние итератора, включая все перемешанные данные.

В следующем примере кода показано, как добавить простой итератор к той же контрольной точке, что и для переменных:

ds = tf.data.Dataset.range(10)
iterator = ds.make_initializable_iterator()

# [Build the training graph, using `iterator.get_next()` as the input.]

# Build the iterator SaveableObject.
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)

# Add the SaveableObject to the SAVEABLE_OBJECTS collection so
# it will be saved automatically using a Saver.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

# Create a saver that saves all objects in the `tf.GraphKeys.SAVEABLE_OBJECTS`
# collection.
saver = tf.train.Saver()

with tf.Session() as sess:
  while continue_training:

    # [Perform training.]

    if should_save_checkpoint:
      saver.save(sess, ...)

Обратите внимание, что поддержка контрольных точек итератора в настоящее время (начиная с TensorFlow 1.8) находится в экспериментальном состоянии, поэтому формат контрольных точек может изменяться от одной версии к следующей.

...