Мой способ загрузки данных tfrecord в набор данных tf:
def CUB_load_data():
ds = tfds.load('caltech_birds2011', download=False, data_dir='../../datasets/')
train_data = ds['train']
test_data = ds['test']
train_x = []
train_y = []
test_x = []
test_y = []
for i in train_data.__iter__():
resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
train_x.append(resized)
train_y.append(i['label'])
for i in test_data.__iter__():
resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
test_x.append(resized)
test_y.append(i['label'])
return (train_x, train_y), (test_x, test_y)
CUB_load_data()
открывает доступ к файлам tfrecords и возвращает список изображений и меток numpy массивов.
def load_data():
(train_x, train_y), (test_x, test_y) = CUB_load_data()
SHUFFLE_BUFFER_SIZE = 500
BATCH_SIZE = 2
@tf.function
def _parse_function(img, label):
feature = {}
img = tf.cast(img, dtype=tf.float32)
img = img / 255.0
feature["img"] = img
feature["label"] = label
return feature
train_dataset_raw = tf.data.Dataset.from_tensor_slices(
(train_x, train_y)).map(_parse_function)
test_dataset_raw = tf.data.Dataset.from_tensor_slices(
(test_x, test_y)).map(_parse_function)
train_dataset = train_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
BATCH_SIZE)
test_dataset = test_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
BATCH_SIZE)
return train_dataset, test_dataset
It выдает Allocation of 3609059328 exceeds 10% of system memory.
предупреждение, и оно застревает. Я попытался установить batch_size на очень низкое число, например 2, но оно все равно не работает с тем же предупреждением и застряло. Есть ли лучший способ загрузки файла tfrecords в наборы данных tf? Что не так с моим кодом?