Лучший способ обработать терабайты данных на gcloud ml-engine с помощью keras - PullRequest
0 голосов
/ 04 февраля 2019

Я хочу обучить модель примерно 2 ТБ данных изображения в хранилище gcloud.Я сохранил данные изображения как отдельные tfrecords и попытался использовать api данных tenorflow, следуя этому примеру

https://medium.com/@moritzkrger/speeding-up-keras-with-tfrecord-datasets-5464f9836c36

Но похоже, что keras 'model.fit(...) не поддерживает проверку для tfrecordнаборы данных на основе

https://github.com/keras-team/keras/pull/8388

Есть ли лучший подход для обработки больших объемов данных с помощью кера из ml-engine, который мне не хватает?

Большое спасибо!

1 Ответ

0 голосов
/ 04 февраля 2019

Если вы хотите использовать tf.keras вместо реальных Keras, вы можете создать экземпляр TFRecordDataset с помощью tf.data API и передать его непосредственно model.fit(). Бонус: вы получаете потоковую передачу прямо из хранилища Google Cloud, нет необходимости сначала загружать данные :

# Construct a TFRecordDataset
ds_train tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)

model.fit(ds_train)

Чтобы включить данные проверки, создайте TFRecordDataset с проверкой TFRecords ипередать его аргументу validation_data в model.fit().Примечание: это возможно с TensorFlow 1.9 .

Последнее замечание: вам необходимо указать аргумент steps_per_epoch.Хак, который я использую, чтобы узнать общее количество примеров во всех файлах TFRecord, состоит в том, чтобы просто перебрать файлы и сосчитать:

import tensorflow as tf

def n_records(record_list):
    """Get the total number of records in a collection of TFRecords.
    Since a TFRecord file is intended to act as a stream of data,
    this needs to be done naively by iterating over the file and counting.
    See https://stackoverflow.com/questions/40472139

    Args:
        record_list (list): list of GCS paths to TFRecords files
    """
    counter = 0
    for f in record_list:
        counter +=\
            sum(1 for _ in tf.python_io.tf_record_iterator(f))
    return counter 

, которые вы можете использовать для вычисления steps_per_epoch:

n_train = n_records([gs://path-to-tfrecords/record1,
                     gs://path-to-tfrecords/record2])

steps_per_epoch = n_train // batch_size
...