Tensorflow `from_tensor_slices` занимает слишком много времени, чтобы завершить sh и выбросить память превысила 10% предупреждения системной памяти - PullRequest
0 голосов
/ 06 января 2020

Мой способ загрузки данных tfrecord в набор данных tf:

def CUB_load_data():

    ds = tfds.load('caltech_birds2011', download=False, data_dir='../../datasets/')
    train_data = ds['train']
    test_data = ds['test']


    train_x = []
    train_y = []

    test_x = []
    test_y = []

    for i in train_data.__iter__():
        resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
        train_x.append(resized)
        train_y.append(i['label'])

    for i in test_data.__iter__():
        resized = cv2.resize(i['image'].numpy(), dsize=(224,224))
        test_x.append(resized)
        test_y.append(i['label'])
    return (train_x, train_y), (test_x, test_y)

CUB_load_data() открывает доступ к файлам tfrecords и возвращает список изображений и меток numpy массивов.

def load_data():
    (train_x, train_y), (test_x, test_y) =  CUB_load_data()
    SHUFFLE_BUFFER_SIZE = 500
    BATCH_SIZE = 2
    @tf.function
    def _parse_function(img, label):
        feature = {}
        img = tf.cast(img, dtype=tf.float32)
        img = img / 255.0
        feature["img"] = img
        feature["label"] = label
        return feature

    train_dataset_raw = tf.data.Dataset.from_tensor_slices(
        (train_x, train_y)).map(_parse_function)
    test_dataset_raw = tf.data.Dataset.from_tensor_slices(
        (test_x, test_y)).map(_parse_function)
    train_dataset = train_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
        BATCH_SIZE)
    test_dataset = test_dataset_raw.shuffle(SHUFFLE_BUFFER_SIZE).batch(
        BATCH_SIZE)
    return train_dataset, test_dataset

It выдает Allocation of 3609059328 exceeds 10% of system memory. предупреждение, и оно застревает. Я попытался установить batch_size на очень низкое число, например 2, но оно все равно не работает с тем же предупреждением и застряло. Есть ли лучший способ загрузки файла tfrecords в наборы данных tf? Что не так с моим кодом?

...