Если вы хотите использовать tf.keras
вместо реальных Keras, вы можете создать экземпляр TFRecordDataset
с помощью tf.data
API и передать его непосредственно model.fit()
. Бонус: вы получаете потоковую передачу прямо из хранилища Google Cloud, нет необходимости сначала загружать данные :
# Construct a TFRecordDataset
ds_train tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)
model.fit(ds_train)
Чтобы включить данные проверки, создайте TFRecordDataset
с проверкой TFRecords ипередать его аргументу validation_data
в model.fit()
.Примечание: это возможно с TensorFlow 1.9 .
Последнее замечание: вам необходимо указать аргумент steps_per_epoch
.Хак, который я использую, чтобы узнать общее количество примеров во всех файлах TFRecord, состоит в том, чтобы просто перебрать файлы и сосчитать:
import tensorflow as tf
def n_records(record_list):
"""Get the total number of records in a collection of TFRecords.
Since a TFRecord file is intended to act as a stream of data,
this needs to be done naively by iterating over the file and counting.
See https://stackoverflow.com/questions/40472139
Args:
record_list (list): list of GCS paths to TFRecords files
"""
counter = 0
for f in record_list:
counter +=\
sum(1 for _ in tf.python_io.tf_record_iterator(f))
return counter
, которые вы можете использовать для вычисления steps_per_epoch
:
n_train = n_records([gs://path-to-tfrecords/record1,
gs://path-to-tfrecords/record2])
steps_per_epoch = n_train // batch_size