Фильтруйте набор данных, чтобы получить только изображения определенного класса - PullRequest
1 голос
/ 17 апреля 2019

Я хочу подготовить набор данных omniglot для обучения n-shot.Поэтому мне нужно 5 образцов из 10 классов (алфавит)

Код для воспроизведения

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

builder = tfds.builder("omniglot")
# assert builder.info.splits['train'].num_examples == 60000
builder.download_and_prepare()
# Load data from disk as tf.data.Datasets
datasets = builder.as_dataset()
dataset, test_dataset = datasets['train'], datasets['test']


def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image, )
    image = image / 255
    one_hot_label = np.zeros((51, 10))
    return image, one_hot_label, example['alphabet']


def stack(image, label, alphabet):
    return (image, label), label[-1]

def filter_func(image, label, alphabet):
    # get just images from alphabet in array, not just 2
    arr = np.array(2,3,4,5)
    result = tf.reshape(tf.equal(alphabet, 2 ), [])
    return result

# correct size
dataset = dataset.map(resize)
# now filter the dataset for the batch
dataset = dataset.filter(filter_func)
# infinite stream of batches (classes*samples + 1)
dataset = dataset.repeat().shuffle(1024).batch(51)
# stack the images together
dataset = dataset.map(stack)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)

for i, (image, label) in enumerate(tfds.as_numpy(dataset)):
    print(i, image[0].shape)

Теперь я хочу отфильтровать изображения в наборе данных с помощью функции фильтра.tf.equal, просто позвольте мне фильтровать по одному классу, я хочу что-то вроде тензорного в массиве.

Вы видите способ сделать это с помощью функции фильтра?Или это неправильный путь, и есть гораздо более простой способ?

Я хочу создать серию из 51 изображения и соответствующих меток, которые относятся к одному и тому же классу N = 10.Из каждого класса мне нужно K = 5 разных изображений и дополнительное (которое мне нужно классифицировать).Каждая партия из N * K + 1 (51) изображений должна быть из 10 новых случайных классов.

Заранее большое спасибо.

1 Ответ

2 голосов
/ 17 апреля 2019

tf.equal() поддерживает вещание и позволяет сравнивать скаляры с тензорами rank > 0.

Чтобы сохранить только определенные метки, используйте этот предикат:

dataset = datasets['train']

def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
    label = x['label']
    isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
    return tf.greater(reduced, tf.constant(0.))

dataset = dataset.filter(predicate).batch(20)

for i, x in enumerate(tfds.as_numpy(dataset)):
    print(x['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]

allowed_labels указывает метки, которые вы хотите сохранить. Все метки, не входящие в этот тензор, будут отфильтрованы.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...