Используя ответ от здесь , ссылку на который я предоставил вам в комментарии выше, я мог бы отфильтровать набор данных, чтобы он включал метки 0, 1 и 2 только следующим образом:
import tensorflow_datasets as tfds
import tensorflow as tf
def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
label = x['label']
isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
return tf.greater(reduced, tf.constant(0.))
cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")
filtered_ds_train=ds_train.filter(predicate)
filtered_ds_test=ds_test.filter(predicate)
Теперь итерируя и печатая метки для Filter_ds_train , мы видим, что выбраны только 3 метки.
for x in myclasses:
print(x['label'])
Вы можете изменить allow_labels = tf.constant ([0., 1., 2.]) аргумент для включения других меток классов. В настоящее время он выбирает метки 0, 1 и 2.