Пакетная предварительная обработка при работе с наборами данных в keras - PullRequest
0 голосов
/ 12 июня 2019

У меня есть несколько примеров матрицы данных переменной длины и связанной с ней метки, и я хочу обучить ее работе с сетью LSTM.Я знаю, что должен заполнять выборки данных (например, используя keras.preprocessing.sequence.pad_sequences), по крайней мере, для каждого пакета, и я сделал это успешно для подачи в сеть пустых массивов, но я не знаю, как это сделать с наборами данных TFRecord.

У меня есть типичный код чтения для моего файла TFRecord:

featuresDict = {'data': tf.FixedLenSequenceFeature([], dtype=tf.string),
                'dataShape': tf.FixedLenSequenceFeature([], dtype=tf.int64),
                'label': tf.FixedLenSequenceFeature([], dtype=tf.int64)
               }

def parse_tfrecord(example):
    context, features = tf.parse_single_sequence_example(example, sequence_features=featuresDict)   
    label = features['label']
    data_shape = features['dataShape']
    data = tf.decode_raw(features['data'], tf.int64)
    data = tf.reshape(data, data_shape)
    return label, data

def DataGenerator(fileName, numEpochs=None, batchSize=None):    
  dataset = tf.data.TFRecordDataset(fileName, compression_type='GZIP')
  dataset = dataset.map(parse_tfrecord)
  dataset = dataset.batch(batchSize)
  dataset = dataset.repeat(numEpochs)
  return dataset

Я могу проанализировать каждый пример и сгенерировать исходную матрицу данных и метки.Затем функция DataGenerator определяет набор данных и устанавливает пакетные и повторные функции этого.Затем я создаю объект DataGenerator и использую его для соответствия моей модели:

train_data = DataGenerator(fileName='train.gz', numEpochs=epochs, batchSize=batch_size)
model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, ...)

Где я могу поместить функцию заполнения в коде?В общем, если я хочу выполнить предварительную обработку на уровне пакета с API набора данных, как я могу это сделать?

...