Класс Склеарна "Стратифицированный ШафлСплит" - PullRequest
1 голос
/ 10 января 2020

Меня немного смущает, как работает класс StratifiedShuffleSplit Склеарна.

Код, приведенный ниже, взят из книги Жерона "Практическое обучение в машине", глава 2, где он проводит стратифицированную выборку.

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

Особенно, что происходит в split.split?

Спасибо!

Ответы [ 2 ]

1 голос
/ 10 января 2020

Функция split.split () возвращает индексы для образцов поезда и тестовых образцов. Он проверит его на наличие указанного количества перекрестной проверки и будет возвращать каждый раз индексы обучающих и тестовых выборок, используя которые можно создать набор обучающих и тестовых данных путем фильтрации всего набора данных.

1 голос
/ 10 января 2020

Поскольку вы не предоставили набор данных, я использую образец sklearn, чтобы ответить на этот вопрос.

Подготовить набор данных

# generate data
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])

Это создает набор данных data, который имеет 6 наблюдений и 2 переменные. group_label имеет 2 значения, означает group 0 и group 1. В этом случае group 0 содержит 3 сэмпла, это же group 1. В общем, размер группы не обязательно должен быть одинаковым.

Создать StratifiedShuffleSplit экземпляр объекта

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
sss.get_n_splits(data, group_label)

Out:

5

На этом шаге вы можете создать экземпляр StratifiedShuffleSplit, вы можете указать функции, как разделить (В random_state = 0, разделить данные 5 times, каждый раз, когда 50% данных будет разделить до test set). Однако данные разделяются только при вызове на следующем шаге.

Вызов экземпляра и разделение данных.

# the instance is actually a generater
type(sss.split(data, group_label))

# split data
for train_index, test_index in sss.split(data, group_label):
     print("n_split",,"TRAIN:", train_index, "TEST:", test_index)
     X_train, X_test = X[train_index], X[test_index]
     y_train, y_test = y[train_index], y[test_index]
type(sss.split(data, group_label))

out:

TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]

На этом шаге spliter, который вы определили на последнем шаге, сгенерирует 5 разбиений data один за другим. Например, в первом разделении исходные данные перемешиваются, и выборка 5,2,3 выбирается в качестве набора поездов; во втором разделении данные снова перемешиваются, и выборка 5,1,4 выбирается в качестве набора поездов; и др c ..

...