Как переключить итератор модели между поездом и проверкой наборов данных? - PullRequest
0 голосов
/ 31 января 2019

Я изучаю TensorFlow "более низкий API", где вы вручную определяете слои с помощью tf.layers, создаете наборы данных и итераторы, а также запускаете циклы для обучения и проверки модели.Я пытаюсь провести обучение и проверку.К сожалению, я сталкиваюсь с ошибками при попытке переключения между наборами данных обучения и проверки:

Вот что у меня есть:

self.train_it = \
    train_dataset.batch(self.batch_size).make_initializable_iterator()
self.validate_it = \
    train_dataset.batch(self.batch_size).make_initializable_iterator()

...

input_layer = self.train_it.get_next()[0]
hidden1 = tf.layers.dense(
    input_layer,
    ... )

...

with tf.name_scope('train'):
  self.train_op = \
        tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

...

for epo in range(epochs):
  # Train using self.train_it iterator.
  self.sess.run(self.train_it.initializer)
  total_loss = 0
  for iteration in range(n_batches):
    summary, _, batch_loss = self.sess.run([self.merged_summary, \
        self.train_op, self.loss])
    total_loss += batch_loss
  print('   Epoch : {}/{}, Training loss = {:.4f}'. \
            format(epo+1, epochs, total_loss / n_batches))
  # Validate using self.valid_it iterator.
  self.sess.run(self.validate_it.initializer)
  # HOW DO I TELL THE MODEL TO USE self.valid_it INSTEAD OF self.train_it ???

Проблема в том, что в начале я уже говорил моделииспользовать train_it: input_layer = self.train_it.get_next()[0], и теперь я должен сказать ему переключаться между train_it и validate_it каждую эпоху.Я должен что-то упустить в API о том, как это сделать.

1 Ответ

0 голосов
/ 31 января 2019

Я бы использовал переинициализируемый итератор и сделал бы следующее.

train_dataset = train_dataset.batch(batch_size_train)
val_dataset = validation_dataset.batch(batch_size_val)

iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)

data, labels = iterator.get_next()

Затем свяжите данные и метки в модели.После этого во время тренировки сделайте следующее:

for e in range(epochs):
    sess.run(train_init_op)
    for iteration in range(n_batches_val):
        ....
    sess.run(val_init_op)
    for iteration in range(n_batches_val):
        ....

Дайте мне знать, если найдете что-то непонятное.

...