Tensorflow 2.0: как преобразовать из MapDataset (после чтения из TFRecord) в некоторую структуру, которая может быть введена в model.fit - PullRequest
1 голос
/ 09 марта 2020

Я сохранил свои данные обучения и проверки в двух отдельных файлах TFRecord, в которых я храню 4 значения: сигнал A (форма float32 (150,)), сигнал B (форма float32 (150,)), метка (скаляр int64), id (строка). Моя функция синтаксического анализа для чтения:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(sample_proto, raw_signal_description)

, где SIGNALS - отображение словаря имя сигнала -> форма сигнала . Затем я читаю необработанные наборы данных:

training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')

и использую карту для анализа значений:

training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

Отображение заголовка training_data или val_data, я получаю:

<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>

, что вполне соответствует ожиданиям. Я также проверил некоторые значения на предмет согласованности, и они казались правильными.

Теперь к моей проблеме: как мне получить из MapDataset со структурой, подобной словарю, что-то, что может быть в качестве входных данных для модели?

Входными данными для моей модели является пара (сигнал A, метка), хотя в будущем я также буду использовать сигнал B.

Мне показалось, что самым простым способом было создать генератор элементов, которые я хочу. Что-то вроде:

def data_generator(mapdataset):
    for sample in mapdataset:
        yield (sample['Signal A'], sample['label'])

Однако при таком подходе я теряю некоторые удобства наборов данных, такие как пакетная обработка, и также не ясно, как использовать тот же подход для параметра validation_data для model.fit , В идеале, я бы конвертировал только между представлением карты и представлением набора данных, где он перебирает пары тензоров и меток сигнала А.

РЕДАКТИРОВАТЬ: мой конечный продукт должен быть чем-то с заголовком, похожим на: <TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)> Но не обязательно TensorSliceDataset

1 Ответ

0 голосов
/ 09 марта 2020

Вы можете просто сделать это в функции разбора. Например:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)

    return parsed['Signal A'], parsed['label']

Если вы map эту функцию над TFRecordDataset, у вас будет набор данных кортежей (signal_a, label) вместо набора словарей. Вы должны быть в состоянии поместить это в model.fit напрямую.

...