Использование категориальных переменных в statsmodels OLS class - PullRequest
1 голос
/ 18 апреля 2019

Я хочу использовать statsmodels класс OLS для создания модели множественной регрессии. Рассмотрим следующий набор данных:

import statsmodels.api as sm
import pandas as pd
import numpy as np

dict = {'industry': ['mining', 'transportation', 'hospitality', 'finance', 'entertainment'],
  'debt_ratio':np.random.randn(5), 'cash_flow':np.random.randn(5) + 90} 

df = pd.DataFrame.from_dict(dict)

x = data[['debt_ratio', 'industry']]
y = data['cash_flow']

def reg_sm(x, y):
    x = np.array(x).T
    x = sm.add_constant(x)
    results = sm.OLS(endog = y, exog = x).fit()
    return results

Когда я запускаю следующий код:

reg_sm(x, y)

Я получаю следующую ошибку:

TypeError: '>=' not supported between instances of 'float' and 'str'

Я пытался преобразовать переменную industry в категориальную, но все равно получаю ошибку. У меня нет выбора.

1 Ответ

0 голосов
/ 18 апреля 2019

Вы на правильном пути с преобразованием в Категориальный dtype.Однако после преобразования DataFrame в массив NumPy вы получите object dtype (массивы NumPy представляют собой один единый тип в целом).Это означает, что отдельные значения по-прежнему лежат в основе str, что определенно не понравится регрессии.

Что вы можете сделать, это dummify этой функции.Вместо факторизации it, которая будет эффективно обрабатывать переменную как непрерывную, вы хотите сохранить некоторое подобие категоризации:

>>> import statsmodels.api as sm
>>> import pandas as pd
>>> import numpy as np
>>> np.random.seed(444)
>>> data = {
...     'industry': ['mining', 'transportation', 'hospitality', 'finance', 'entertainment'],
...    'debt_ratio':np.random.randn(5),
...    'cash_flow':np.random.randn(5) + 90
... }
>>> data = pd.DataFrame.from_dict(data)
>>> data = pd.concat((
...     data,
...     pd.get_dummies(data['industry'], drop_first=True)), axis=1)
>>> # You could also use data.drop('industry', axis=1)
>>> # in the call to pd.concat()
>>> data
         industry  debt_ratio  cash_flow  finance  hospitality  mining  transportation
0          mining    0.357440  88.856850        0            0       1               0
1  transportation    0.377538  89.457560        0            0       0               1
2     hospitality    1.382338  89.451292        0            1       0               0
3         finance    1.175549  90.208520        1            0       0               0
4   entertainment   -0.939276  90.212690        0            0       0               0

Теперь у вас есть dtypes, с которыми statsmodels может лучше работать.Цель drop_first состоит в том, чтобы избежать фиктивной ловушки :

>>> y = data['cash_flow']
>>> x = data.drop(['cash_flow', 'industry'], axis=1)
>>> sm.OLS(y, x).fit()
<statsmodels.regression.linear_model.RegressionResultsWrapper object at 0x115b87cf8>

И наконец, просто маленький указатель: это помогает избежать именования ссылок с именами, которые встроены в тенитипы объектов, такие как dict.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...