В TensorFlow 2.0 как узнать количество элементов в наборе данных? - PullRequest
1 голос
/ 30 мая 2019

Когда я загружаю набор данных, мне интересно, есть ли какой-нибудь быстрый способ найти количество выборок или партий в этом наборе данных.Я знаю, что если я загружаю набор данных с with_info=True, я могу видеть, например, total_num_examples=6000,, но эта информация недоступна, если я разделю набор данных.

В настоящее время я считаю количество образцов следующим образом, но думаю, есть ли лучшее решение:

train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)

cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)

cifar10_trainsub3 = cifar10_trainsub3.batch(1000)

n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
    print(i, n, batch['image'].shape)
    n += len(batch['image'])

print(i, n)

1 Ответ

1 голос
/ 30 мая 2019

Если возможно узнать длину, вы можете использовать:

tf.data.experimental.cardinality(dataset)

но проблема в том, что набор данных TF по своей природе лениво загружается. Таким образом, мы можем не знать размер набора данных заранее. Действительно, вполне возможно, что набор данных представляет бесконечный набор данных!

Если это достаточно маленький набор данных, вы можете просто перебрать его, чтобы получить длину. Ранее я использовал следующую некрасивую небольшую конструкцию, но это зависит от того, насколько набор данных достаточно мал, чтобы мы были счастливы загрузить его в память, и это действительно не улучшение по сравнению с вашим циклом for выше!

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1
...