Должен ли я возвращать набор данных напрямую или вместо этого использовать итератор one_shot? - PullRequest
0 голосов
/ 04 февраля 2019

Я строю конвейер данных, используя Dataset API, но когда я тренируюсь с несколькими графическими процессорами и возвращаю dataset.make_one_shot_iterator (). Get_next () в своей функции ввода, я получаю ValueError: dataset_fn () должен вернуть tf.data.Набор данных при использовании tf.distribute.Strategy.Я могу следить за сообщением об ошибке и возвращать набор данных напрямую, но я не понимаю цели итератора (). Get_next () и как он работает для обучения на одном или нескольких GPU.

...

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size = batch_size)
    dataset = dataset.cache()

    dataset = dataset.prefetch(buffer_size=None)

    return dataset.make_one_shot_iterator().get_next()

return _input_fn

1 Ответ

0 голосов
/ 04 февраля 2019

При использовании tf.data со стратегией распространения (которую можно использовать с Keras и tf.Estimator s), ваш ввод fn должен возвращать tf.data.Dataset:

def input_fn():
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size = batch_size)
  dataset = dataset.cache()

  dataset = dataset.prefetch(buffer_size=None)
  return dataset

...use input_fn...

См. документацию по стратегии распространения.

dataset.make_one_shot_iterator() полезен вне стратегий распространения / библиотек более высокого уровня, например, если вы используете библиотеки более низкого уровня или отлаживаете / тестируете набор данных.Например, вы можете перебрать все элементы набора данных следующим образом:

dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with tf.Session() as sess:
  while True:
    print(sess.run(get_next))
  except tf.errors.OutOfRangeError:
    break
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...