Соответствие цели именам цели в fetch_20newsgroups - PullRequest
3 голосов
/ 16 апреля 2020

Возможно, это глупый вопрос, но я не могу найти способ сопоставить метку цели в fetch_20newsgroups с именем цели. Это что-то столь же очевидное, как alt.atheism == 1, и поэтому я нигде не могу его найти, или есть метод, который мне не подходит?

>>> from sklearn.datasets import fetch_20newsgroups
>>> newsgroups_train = fetch_20newsgroups(subset='train')

>>> from pprint import pprint
>>> pprint(list(newsgroups_train.target_names))
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']
>>> newsgroups_train.target[:10]
array([12,  6,  9,  8,  6,  7,  9,  2, 13, 19])

1 Ответ

2 голосов
/ 16 апреля 2020

Определенно не глупый вопрос, так как я также не смог найти никаких документов по этому вопросу.

Я взглянул на исходный код функции fetch_20newsgroups из здесь .

def fetch_20newsgroups(data_home=None, subset='train', categories=None,  # line#-149
                       shuffle=True, random_state=42,
                       remove=(),
                       download_if_missing=True, return_X_y=False):
    """Load the filenames and data from the 20 newsgroups dataset \
(classification).
    Download it if necessary.
...
...
    categories : None or collection of string or unicode                 # line#-177
        If None (default), load all the categories.
        If not None, list of category names to load (other categories
        ignored).
...
...
    """
...
...
    if categories is not None:                                           # line#-287
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        mask = np.in1d(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]                                  # line#-294
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()
...
...
    return data

Обратите внимание, что одним из параметров является categories И из строки документации,

Если нет (по умолчанию), загрузить все категории.
Если нет, нет , список имен категорий для загрузки

Таким образом, по умолчанию categories со всеми target_names.

Теперь давайте go к строке # -287 исходного кода.
Вы можете видеть, что когда задано categories, оно сортируется на основе индекса каждого category от target_names.

А позже в строке # -294 target фильтруется на основе этих индексов.
Что говорит нам о том, что те числа, которые вы получаете из target, на самом деле
являются индексами категорий из target_names.

Поэтому вы можете сопоставить каждого из них по индексу из target_names.

for idx, cat in enumerate(newsgroups_train.target_names):
    print(idx, cat)
0 alt.atheism
1 comp.graphics
2 comp.os.ms-windows.misc
3 comp.sys.ibm.pc.hardware
4 comp.sys.mac.hardware
5 comp.windows.x
6 misc.forsale
7 rec.autos
8 rec.motorcycles
9 rec.sport.baseball
10 rec.sport.hockey
11 sci.crypt
12 sci.electronics
13 sci.med
14 sci.space
15 soc.religion.christian
16 talk.politics.guns
17 talk.politics.mideast
18 talk.politics.misc
19 talk.religion.misc
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...