GroupShuffleSplit в Scikit - учиться не группировать по указанным группам правильно - PullRequest
1 голос
/ 26 сентября 2019

У меня есть следующие данные:

    pd.DataFrame({'Group_ID':[1,1,1,2,2,2,3,4,5,5],
              'Year':[2011,2012,2013,2011,2012,2013,2011,2011,2011,2012],
              'Feature': [163,173,183,25,53,78,82,36,60,71]})

      Group_ID    Year   Feature_x
    0   1         2011    163
    1   1         2012    173
    2   1         2013    183
    3   2         2011    25
    4   2         2012    53
    5   2         2013    78
    6   3         2011    82
    7   4         2011    36
    8   5         2011    60
    9   5         2012    71

Мне нужно разделить набор данных на обучающий и тестовый набор на основе «Group_ID», чтобы 80% данных были помещены в обучающий набор и 20% в тестовом наборе.

То есть мне нужно, чтобы мой тренировочный набор выглядел примерно так:

Group_ID Year   Feature_x
    1    2011    163
    1    2012    173
    1    2013    183
    2    2011    25
    2    2012    53
    2    2013    78
    3    2011    82
    4    2011    36

И тестовый набор:

   Group_ID Year    Feature
    5       2011      60
    5       2012      71

(Это просто обобщенное представление моегонабор данных, поскольку он охватывает> 10000 строк и имеет> 300 объектов.)


Я знаю, что функция sklearn.model_selection.GroupShuffleSplit() может этого добиться, но когда я пытаюсь ее использовать, она не работает, как яожидать.Ниже приведен пример кода того, как я использовал GroupShuffleSplit.

from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(n_splits=1, train_size=0.8)

for train_indices, test_indices in gss.split(df, groups = df.Group_ID):
    print('Train indices:', train_indices)
    print('Test indices:', test_indices)

Выходы:

Train indices: [0, 1, 3, 4, 5, 6, 8]
Test indices: [2, 7, 9]

Или в виде таблицы:

Поезд:

      Group_ID    Year   Feature_x
    0   1         2011    163
    1   1         2012    173
    3   2         2011    25
    4   2         2012    53
    5   2         2013    78
    6   3         2011    82
    8   5         2011    60

Тест:

      Group_ID    Year   Feature_x
    2   1         2013    183
    7   4         2011    36
    9   5         2012    71

Как видите, есть несколько групп, которые содержатся как в поезде, так и в тестовом наборе.Я проверил, что тип данных Group_ID одинаков во всем фрейме данных, поэтому я не совсем уверен, что вызывает эту проблему.


edit:

Я использовал df.loc[train_indices] и df.loc[test_indices], чтобы посмотреть значения строк обучения и тестирования.В строках поезда все значения, которые должны быть Ints, являются числами с плавающей запятой, а строки тестов нормальные.

train_df = df.loc[train_indices]
test_df = df.loc[test_indices]


train_df.head()
    Year    Group_ID    Feature_x
0   2010.0  14.0    1.445161
1   2010.0  15.0    1.445161
4   2010.0  18.0    1.445161
5   2010.0  20.0    0.835484
6   2010.0  21.0    0.835484


test_df.head()
    Year    Group_ID    Feature_x
2   2010    16  1.445161
3   2010    17  1.445161
7   2010    25  0.835484
12  2010    33  0.835484
15  2010    37  1.009677

Кроме того, использование метода info() показывает, что эти два имеют разные типы столбцов, несмотря наисходящий из того же кадра данных.

train_df.head().info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 5 entries, 0 to 6
Data columns (total 4 columns):
Year               5 non-null float64
Group_ID           5 non-null float64
Feature_x          5 non-null float64
dtypes: float64(3)
memory usage: 200.0 bytes


test_df.head().info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 5 entries, 2 to 15
Data columns (total 4 columns):
Year               5 non-null int64
Group_ID           5 non-null int64
Feature_x          5 non-null float64
dtypes: float64(1), int64(2)
memory usage: 200.0 bytes

Затем, если я посмотрю на df.dtypes(), я увижу, что типы данных этого кадра соответствуют ожидаемым.

df.dtypes()

Year                 int64
Group_ID             int64
Feature_x            float64
                     ...   
Length: 387, dtype: object

IТакже хочу добавить, что сначала вошел входной фрейм данных, в котором большинство столбцов были «объектами», а затем я преобразовал их в Ints и Floats.Я не думаю, что это должно на что-то повлиять, но я мог бы также заявить об этом.

...