Сокращение классов набора данных - PullRequest
1 голос
/ 16 марта 2020

Допустим, у меня инициализирован набор данных CIFAR-100 (images) следующим образом:

cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")

Например, ds_train является объектом типа:

<DatasetV1Adapter shapes: {coarse_label: (), image: (32, 32, 3), label: ()}, types: {coarse_label: tf.int64, image: tf.uint8, label: tf.int64}> which is a `tf.data.dataset`

Этот набор данных содержит 100 классов. Допустим, у меня также есть список с именем our_index, в котором есть 20 различных элементов, каждый из которых представляет один класс. Я хотел бы выполнить итерацию по набору данных ds_train и сохранить только те элементы, которые принадлежат одному из этих 20 классов. Чтобы сделать это, я думаю, что я мог бы использовать это: [https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter] [1] .

но я не уверен, как. Любые идеи?

1 Ответ

1 голос
/ 16 марта 2020

Используя ответ от здесь , ссылку на который я предоставил вам в комментарии выше, я мог бы отфильтровать набор данных, чтобы он включал метки 0, 1 и 2 только следующим образом:

import tensorflow_datasets as tfds
import tensorflow as tf

def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
    label = x['label']
    isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
    return tf.greater(reduced, tf.constant(0.))

cifar100_builder = tfds.builder("cifar100")
cifar100_builder.download_and_prepare()
ds_train = cifar100_builder.as_dataset(split="train")
ds_test = cifar100_builder.as_dataset(split="test")

filtered_ds_train=ds_train.filter(predicate)
filtered_ds_test=ds_test.filter(predicate)

Теперь итерируя и печатая метки для Filter_ds_train , мы видим, что выбраны только 3 метки.

for x in myclasses:
  print(x['label'])

Вы можете изменить allow_labels = tf.constant ([0., 1., 2.]) аргумент для включения других меток классов. В настоящее время он выбирает метки 0, 1 и 2.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...