Должен ли я вызывать iter.get_next при использовании набора данных Tensorflow? - PullRequest
0 голосов
/ 25 апреля 2018

Я преобразовал код из системы, основанной на очереди, в набор данных tenorflow. После преобразования я вижу потерю точности и время увеличилось. Я приписываю это неправильной реализации на моем конце, и в настоящее время я пытаюсь устранить проблему, которая может быть проблемой. Теперь методом проб и ошибок в этом преобразовании я сделал несколько предположений, основанных на ряде статей и примеров, с которыми столкнулся, и я просто хотел убедиться, что моя текущая реализация верна и что мои предположения были также.

Ранее у меня было огромное количество изображений, и я собирал их в очередь, затем извлекал из очереди мои 100 изображений, выполнял обработку и суммирование, а затем продолжал. Эта загрузка в память через очередь, по моему мнению, могла стать причиной узкого места, поэтому, когда я услышал об API набора данных, я подумал, что это стоит посмотреть. Итак, теперь я извлекаю всю информацию об изображении и передаю ее своему методу, где затем выполняю пакетную обработку с помощью пакетного метода набора данных. До и после показаны ниже. Я читал, что нет необходимости вызывать iter.get_next для набора данных, поскольку ops будет вызывать его автоматически, однако с точностью, которую я вижу в конце, я сомневаюсь, правда это или нет. В настоящее время, как вы можете видеть, я просто передаю iter.initializer в качестве операции sess.run с другими своими операциями и передаю feed_dict. Любое понимание было бы полезно, поскольку я несколько новичок в этом. Спасибо!

Предыдущая функция выборки при использовании очереди: (Имейте в виду, я бы поставил изображения в очередь в объект blob и передал бы это подмножество этому методу)

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0):
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

Текущая функция выборки с использованием API набора данных: (Теперь перед вызовом этого я заполняю свой объект BLOB-объекта всеми данными и использую приведенные ниже функции пакетной обработки - обратите внимание, что я никогда не выполняю вызов iter.get_next ())

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0, batch_size=32):
        dataset = tf.data.Dataset.from_tensor_slices((self._input_images, self._input_labels,
                                                      self._input_weights)).repeat().batch(batch_size)

        iter = dataset.make_initializable_iterator()
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        _, summary, acc = sess.run([iter.initializer, self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

1 Ответ

0 голосов
/ 01 мая 2018

Из этого фрагмента кода похоже, что вы никогда не используете значения из iter, поэтому это не должно влиять на ваши итоги.Например, вы должны иметь возможность удалить строки, которые создают итератор, и удалить iter.initializer из списка, переданного в sess.run(), и получить тот же результат.

Чтобы ответить на более широкий вопрос «Должен ли яcall iter.get_next()? ": в TensorFlow на основе графа должно быть соединение потока данных между tf.data.Iterator и тензором / операцией, которую вы передаете sess.run(), чтобы получить значения из этого итератора.Если вы используете низкоуровневый API TensorFlow , самый простой способ добиться этого - вызвать iter.get_next() для получения одного или нескольких tf.Tensor объектов, а затем использовать эти тензоры в качестве входных данных для вашей модели..

Однако, если вы используете API высокого уровня tf.estimator, ваш input_fn может возвратить tf.data.Dataset без создания tf.data.Iterator (или вызова Iterator.get_next()API Estimator позаботится о создании итератора и вызове для вас get_next().

...