Как показать распределение классов в объекте Dataset в Tensorflow - PullRequest
1 голос
/ 27 марта 2020

Я работаю над многоклассовой классификационной задачей, используя мои собственные изображения.

filenames = [] # a list of filenames
labels = [] # a list of labels corresponding to the filenames
full_ds = tf.data.Dataset.from_tensor_slices((filenames, labels))

Этот полный набор данных будет перетасован и разбит на обучающий, действительный и тестовый набор данных

full_ds_size = len(filenames)
full_ds = full_ds.shuffle(buffer_size=full_ds_size*2, seed=128) # seed is used for reproducibility

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

Сейчас я пытаюсь понять, как каждый класс распределяется в train_ds, valid_ds и test_ds. Уродливое решение состоит в том, чтобы перебрать все элементы в наборе данных и подсчитать вхождение каждого класса. Есть ли лучший способ решить это?

Мое безобразное решение:

def get_class_distribution(dataset):
    class_distribution = {}
    for element in dataset.as_numpy_iterator():
        label = element[1]

        if label in class_distribution.keys():
            class_distribution[label] += 1
        else:
            class_distribution[label] = 0

    # sort dict by key
    class_distribution = collections.OrderedDict(sorted(class_distribution.items())) 
    return class_distribution


train_ds_class_dist = get_class_distribution(train_ds)
valid_ds_class_dist = get_class_distribution(valid_ds)
test_ds_class_dist = get_class_distribution(test_ds)

print(train_ds_class_dist)
print(valid_ds_class_dist)
print(test_ds_class_dist)

1 Ответ

1 голос
/ 27 марта 2020

Ответ ниже предполагает:

  • существует пять классов.
  • метки являются целыми числами от 0 до 4.

Его можно изменить на в соответствии с вашими потребностями.

Определите функцию счетчика:

def count_class(counts, batch, num_classes=5):
    labels = batch['label']
    for i in range(num_classes):
        cc = tf.cast(labels == i, tf.int32)
        counts[i] += tf.reduce_sum(cc)
    return counts

Используйте операцию reduce:

initial_state = dict((i, 0) for i in range(5))
counts = train_ds.reduce(initial_state=initial_state,
                         reduce_func=count_class)

print([(k, v.numpy()) for k, v in counts.items()])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...