tf.string_input_producer выдает одну эпоху даже после установки эпох более одной - PullRequest
0 голосов
/ 24 октября 2018

Мой входной API использует tf.string_input_producer вместе с tf.parse_single_sequence_example.Когда я устанавливаю num_epochs> 1 в tf.string_input_producer, моя очередь заканчивается только после одной эпохи.

Это ожидаемое поведение или я совершаю какую-то ошибку?Вот соответствующий код:

class TFRecordReader():



    def __init__(self):
        #some code....

    def execute_queue(self, tensor_queue, exception_message: str, log_dir_path=None):
        import os
        if log_dir_path is None:
            path = os.path.abspath('../../audio_log_dir/')
        else:
            path = log_dir_path
        writer = self._summary_file_writer(path)
        coord, thread = self._coord_thread()
        print('should_stop: ', coord.should_stop())

        if not coord.should_stop():
            try:
                if self._data_v is None:
                    self._data_v = self._parse_tensor(tensor_queue)

                 return self._data_v
            except self.tf.errors.OutOfRangeError:
                print(exception_message)
            finally:
                coord.request_stop()
                coord.join(thread)
                writer.close()


    def single_sequence_batch(self,
                              tf_record_path,
                              feature_map,
                              parse_function,
                              num_epochs=None,
                              tf_record_compression=None,
                              queue_completion_message='Data Exhausted!',
                              log_dir_path=None
                              ):
        self.feature_map = feature_map
        self.parse_func = parse_function
        batch = self._single_sequence_batch(tf_record_path=tf_record_path,
                                        num_epochs=num_epochs,
                                        tf_record_compression=tf_record_compression)
        data_queue = self.execute_queue(batch, queue_completion_message, log_dir_path=log_dir_path)
        return data_queue

def _test_single_sequence_batch(num_epochs=1):
    tfr_path = r'C:/audio_tfrecord/audioapi.tfrecord'
    reader = TFRecordReader()
    data_q = reader.single_sequence_batch(tf_record_path=tfr_path,
                                      feature_map=feature_mapping,
                                      parse_function=parse_func,
                                      tf_record_compression=True,
                                      num_epochs=num_epochs)
    print(len(data_q))
    c = 0
    try:
        for i in range(num_epochs):
            val = reader.session.run(data_q)
            print(val)
            c += 1
    except tf.errors.OutOfRangeError:
        print("Total Examples :", c)
        print('Finished!')
...