Проблема с разделением данных из наборов данных Tensorflow - PullRequest
0 голосов
/ 12 марта 2020

Я пытаюсь загрузить данные из набора Oxford Flowers 102 и разделить их на обучающие, проверочные и тестовые наборы с помощью API tfds. Вот мой код:

# Split numbers 
train_split = 60
test_val_split = 20

splits = tfds.Split.ALL.subsplit([train_split,test_val_split, test_val_split])

# TODO: Create a training set, a validation set and a test set.
(training_set, validation_set, test_set), dataset_info = tfds.load('oxford_flowers102', split=splits, as_supervised=True, with_info=True)

Проблема в том, что когда я распечатываю dataset_info я получаю следующие цифры для своих наборов для тестирования, обучения и проверки

total_num_examples=8189,
splits={
    'test': 6149,
    'train': 1020,
    'validation': 1020,
},

Вопрос: Как мне получить данные для разделения на 6149 в обучающем наборе и 1020 в тестовых и проверочных наборах?

1 Ответ

1 голос
/ 14 марта 2020

Кажется, это ошибка в самом наборе данных. Тем более, что общий размер набора данных составляет 8189, а 6149 - это не 60% от общего объема, а 75%, так что вы вообще не выполняли никакого разделения. Они, вероятно, неправильно назвали сплиты. Кроме того, даже когда я пытаюсь загрузить набор данных различными способами, описанными здесь (https://github.com/tensorflow/datasets/blob/master/docs/splits.md), я получаю такое же неправильное разбиение.

Простым решением было бы просто передать модели набор тестов в качестве обучающего набора и наоборот, но у вас не будет желаемого процента. В противном случае вы можете загрузить весь набор данных (поезд + тест + проверка), а затем разделить его самостоятельно.

df_all, summary = tfds.load('oxford_flowers102', split='train+test+validation', with_info=True)

# check if the dataset loaded truly contains everything
df_all_length = [i for i,_ in enumerate(df_all)][-1] + 1

print(df_all_length)
>>out: 8189  # length is fine


train_size = int(0.6 * df_all_length)
val_test_size = int(0.2 * df_all_length)

# split whole dataset 
df_train = df_all.take(train_size)
df_test = df_all.skip(train_size)
df_valid = df_test.skip(val_test_size)
df_test = df_test.take(val_test_size)

df_train_length = [i for i,_ in enumerate(df_train)][-1] + 1
df_val_length = [i for i,_ in enumerate(df_val)][-1] + 1
df_test_length = [i for i,_ in enumerate(df_test)][-1] + 1

# check sizes
print('Train: ', df_train_length)
print('Validation :', df_valid_length)
print('Test :', df_test_length)

>>out: 4913 #(true 60% of 8189)
>>out: 1638 #(true 20% of 8189)
>>out: 1638
...