TensorFlow - чтение многомерных массивов TFRecord и их пакетирование - PullRequest
0 голосов
/ 20 февраля 2019

Я использую TensorFlow для чтения файлов TFRecord, где я сохранил два двумерных массива и один одномерный массив со значениями с плавающей запятой.Данные хранятся в нескольких файлах TFRecord, представляющих огромный набор данных, разделенных на несколько файлов.Для чтения данных я использую функцию dataset.map () , чтобы получить мои проанализированные объекты с ранее сохраненными формами объектов.

Когда я готовил свои файлы TFRecord, я сохранял эти функции в виде плавающих списков.и сплющил их.Кроме того, я сохранил форму каждого из них, чтобы я мог правильно прочитать его позже, когда я читаю эти файлы.

Когда я перебираю данные без пакетирования данных, я получаю обычные массивы функций.Например: если моя форма объекта для значений X (входных) равна (10, 342), то я получу 10 наборов из 342 значений.Но если я установлю пакет базы данных равным 2, то моя форма объекта внезапно превратится в (1, 10, 342).Это похоже на то, что пакетная обработка TensorFlow не распознает, что существует 10 наборов данных, но затем объединяет их в одно целое.

Мой код для чтения файлов TFRecord выглядит следующим образом:

x_shape = (10, 342)
y_shape = (10, 311)
p_shape = (10,)

def _parse_function(example_proto): 
    keys_to_features = {'X':tf.FixedLenFeature(x_shape, tf.float32),
            'Y':tf.FixedLenFeature(y_shape, tf.float32),
            'P':tf.FixedLenFeature(p_shape, tf.float32)}

    parsed_features = tf.parse_single_example(example_proto, keys_to_features)

    return parsed_features['X'], parsed_features['Y'], parsed_features['P']

def _load_tfrecord():
    training_filenames = [TF_DATABASE_PATH + TF_DATABASE_NAME]

    dataset = tf.data.TFRecordDataset(training_filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(2)

    init = tf.global_variables_initializer()
    iterator = dataset.make_initializable_iterator()
    nextElement = iterator.get_next()

    with tf.Session() as sess:          
        sess.run(init)
        sess.run(iterator.initializer)
        currentBatch = sess.run(nextElement)

Итак, мой вопрос, как заставить TensorFlow разделить данные на правильные партии?В этом случае, если у меня есть 10 наборов данных (где у каждого X есть 342 элемента), я хотел бы иметь 5 пакетов по 2 набора в каждом.

Я впервые работаю с TFRecord, поэтому я бы хотелочень ценю вашу помощь!Спасибо.

...