Почему StratifiedShuffleSplit возвращает индексы обучения / тестирования полного набора данных при вызове next ()? - PullRequest
0 голосов
/ 23 сентября 2019

Я пытаюсь получить стратифицированную подвыборку моих данных, поскольку набор данных довольно большой (+ - 100 тыс. Изображений).Я пытался быть умным, используя StratifiedShuffleSplit класс scikit-learn. Документация предоставляет мне следующий пример:

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 0, 1, 1, 1])
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)

for train_index, test_index in sss.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

, который дает следующий вывод (индексы соответствующих разделов Поезд / тест):

TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]

На основевыше, и так как тип StratifiedShuffleSplit является генератором, я ожидал, что следующий код (при вызове next() даст мне один из созданных разбиений.

sss = StratifiedKFold(n_splits=10, random_state=0) 
train_index, test_index = next(sss.split(X, y)) #I expected this call to next would give me the indices of ONE of the (in this case 10) splits                         
print(type(sss.split(X,y)))                         #Type is generator

Однако, когда я проверяю len () после этого, я вижу, что я на самом деле получаю полный набор данных! Может кто-нибудь объяснить мне, почему это происходит, и как я могу достичь своей цели - взять стратифицированный подвыборку?

y_complete = np.concatenate((y[train_index], y[test_index]))            
X_complete = np.concatenate((X[train_index], X[test_index]))             
print(len(y_complete), len(X_complete)) #Gives me full length of dataset (So 99289 instead of expected 9920)

1 Ответ

1 голос
/ 24 сентября 2019

Это ожидаемое поведение, которое вы видите в примере, который вы создали.Если вы посмотрите только на train_index и test_index по отдельности, то увидите, что внутри них есть взаимоисключающий набор индексов.Однако, если вы посмотрите на объединенный набор индексов внутри train_index + test_index, , то объединенный набор индексов будет представлять собой сам полный набор данных. См. Приведенный ниже код для большей ясности:

ss = StratifiedKFold(n_splits=10, random_state=0)
split_gen = sss.split(X, y) # Store this generator in a variable 
train_index, test_index = next(split_gen)                         
print(type(sss.split(X,y)))

print("Length of Training split is {}".format(len(y[train_index])))
print("Indices are {}".format(train_index))
print("Actual data at those indices is {}".format(y[train_index]))

# Output : 
# Length of Training split is 3
# Indices are [5 2 3]
# Actual data at those indices is [1 0 1]

Обратите внимание, что train_index содержит только 3 индекса, а не полный набор данных.Подобное поведение можно увидеть для test_index:

print("Length of Test split is {}".format(len(y[test_index])))
print("Indices are {}".format(test_index))
print("Actual data at those indices is {}".format(y[test_index]))

# Output : 
# Length of Test split is 3
# Indices are [4 1 0]
# Actual data at those indices is [1 0 0]

Здесь вы можете видеть, что [5 2 3] и [4 1 0] являются взаимоисключающими, но объединяются, образуя полный набор данных, который происходит, когда вы используетеnp.concatenate выше.

Чтобы получить следующий сплит, вы используете next на объекте генератора:

train_index, test_index = next(split_gen)
print("Length of Set 2 Training split is {}".format(len(y[train_index])))
print("Indices are {}".format(train_index))
print("Actual data at those indices is {}".format(y[train_index]))

# Length of Set 2 Training split is 3
# Indices are [5 1 4]
# Actual data at those indices is [1 0 1]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...