Метрики линейной регрессии с плохой производительностью - PullRequest
0 голосов
/ 15 сентября 2018

Я пробую линейную регрессию на наборе данных из 637 наблюдений и 10 признаков. Вот как выглядят 100 строк:

train_X:

array([[ 0.        ,  0.33178298,  0.26865475,  0.        , -0.72870151,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -1.59016455, -1.05162249,  0.        , -0.88503151,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -0.50119857, -1.42884455,  0.        , -1.04136151,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.0321526 ,  0.77161751,  0.        , -0.74172901,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.80988656,  1.21170993,  0.        , -0.70264651,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.76509649, -0.61153007,  0.        ,  1.40780853,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  1.45653173,  1.71467268,  1.        , -1.11952651,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  2.70376764,  1.58893199,  0.        , -0.5202615 ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.384769  ,  1.71467268,  1.        ,  0.76946102,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.63022298, -1.4917149 ,  1.        , -0.272739  ,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.90568353,  0.64587682,  0.        , -0.97622401,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.20929581, -0.234308  ,  1.        , -1.08044401,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 1.        , -0.73550689, -0.48578938,  0.        , -0.74172901,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.94401033, -0.61153007,  1.        , -1.01530651,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.53615764, -0.42291904,  1.        ,  0.53496601,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.28134099, -0.67440042,  1.        , -0.87200401,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.53113224, -0.48578938,  1.        ,  1.68138603,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  0.34416932, -0.61153007,  0.        ,  1.52505603,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 1.        , -1.27362468, -1.68032593,  1.        ,  1.18634102,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  0.10951694, -0.234308  ,  0.        ,  1.87679853,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  1.28966016,  0.45726579,  1.        , -0.5463165 ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -1.67618082, -1.4917149 ,  1.        , -0.91108651,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.025221  , -0.54865973,  1.        ,  0.13111351,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  0.04001579, -0.61153007,  1.        , -0.72870151,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.99424383, -0.80014111,  0.        , -0.481179  ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.70337327,  0.8973582 ,  0.        , -1.06741651,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -1.39404746, -1.11449283,  0.        ,  1.65533103,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.86418723, -0.42291904,  1.        , -1.06741651,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -1.69751285, -1.24023352,  1.        ,  0.73037852,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.17949773, -0.73727076,  0.        , -0.5202615 ,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.1571335 , -0.54865973,  0.        , -0.97622401,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        , -0.42378393, -0.86301145,  1.        , -0.168519  ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.4074773 ,  1.65180234,  1.        ,  1.10817602,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  0.37135047, -0.80014111,  0.        , -0.66356401,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.74795485,  0.64587682,  0.        , -0.012189  ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.82020851,  1.02309889,  0.        , -0.376959  ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -0.97015928, -1.4917149 ,  0.        , -0.97622401,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.55390214,  1.40032096,  1.        , -0.0773265 ,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.38132835, -0.10856731,  1.        , -0.74172901,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -0.68974624, -0.17143766,  1.        , -0.75475651,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.27093377,  1.33745062,  1.        ,  0.96487352,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        , -1.43533526, -1.05162249,  0.        , -0.78081151,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 1.        , -1.85647092, -1.61745559,  0.        ,  0.53496601,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.48867666,  1.58893199,  1.        , -0.376959  ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.31643563,  0.39439544,  1.        , -0.96319651,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.02350067,  0.8973582 ,  1.        ,  0.99092852,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.10724406,  0.8973582 ,  0.        ,  0.95184602,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -1.59016455, -0.73727076,  1.        ,  0.33955351,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.09437808,  0.58300648,  1.        , -0.5463165 ,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.3720889 ,  1.52606165,  0.        , -1.09347151,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.66875827, -0.234308  ,  0.        , -0.1554915 ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.72828153, -0.29717835,  1.        , -0.84594901,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.67073739,  1.33745062,  0.        ,  0.83459852,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.36185222, -0.10856731,  0.        , -0.012189  ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.10303001,  0.96022855,  0.        ,  1.70744103,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.05219365, -0.10856731,  1.        , -0.89805901,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.90203439, -0.73727076,  1.        ,  1.19936852,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        ,  0.62492643,  0.45726579,  1.        ,  1.82468853,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -0.77266592, -0.48578938,  1.        , -0.5463165 ,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.21183984, -0.54865973,  1.        , -0.5463165 ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.55164057,  0.83448786,  0.        ,  1.98101854,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.1217564 ,  0.8973582 ,  0.        ,  0.33955351,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        , -1.10640905, -1.11449283,  1.        , -1.02833401,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -1.22201492, -1.36597421,  0.        , -0.533289  ,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.        , -0.38180799, -1.4917149 ,  1.        , -0.5984265 ,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.26778688,  0.08004372,  1.        ,  1.94193604,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.14337089,  0.20578441,  1.        , -0.4160415 ,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  2.46154582,  1.27458027,  0.        ,  0.41771851,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -1.30321427, -1.55458524,  1.        , -1.11952651,
         0.        ,  0.        ,  0.        ,  1.        ,  1.        ],
       [ 1.        , -0.78161161, -0.73727076,  0.        , -0.70264651,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.58762052,  0.70874717,  0.        , -0.298794  ,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -1.00181326, -0.86301145,  1.        , -0.83292151,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.38800116, -0.98875214,  1.        , -1.13255401,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.77872352,  1.21170993,  0.        , -0.93714151,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -1.28807541, -0.36004869,  1.        ,  1.92890854,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -0.50704768, -0.17143766,  0.        , -0.116409  ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.42949746,  1.14883958,  0.        ,  0.82157102,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.46940901, -0.61153007,  1.        , -0.96319651,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.29187143,  0.52013613,  0.        , -1.04136151,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.92136365,  1.21170993,  0.        , -0.68961901,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -1.33624452, -1.24023352,  0.        , -0.87200401,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -0.45474979, -0.73727076,  0.        ,  1.87679853,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.46803275,  1.65180234,  1.        , -1.02833401,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.39337063,  0.20578441,  0.        ,  1.53808353,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.49555796,  1.08596924,  1.        ,  1.14725852,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.42860084, -0.9258818 ,  0.        , -0.74172901,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.34416932,  0.01717338,  0.        , -1.04136151,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 1.        , -0.5817098 , -1.42884455,  1.        ,  0.50891101,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 1.        , -1.40161689, -1.4917149 ,  0.        ,  0.09203101,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.34210493, -0.48578938,  0.        ,  2.09826604,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.34946587,  0.26865475,  0.        ,  2.09826604,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.93527313,  0.3315251 ,  1.        , -1.09347151,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        ,  0.49074105, -0.48578938,  0.        , -0.87200401,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.70454104, -0.98875214,  1.        ,  0.31349851,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  1.95542609,  1.58893199,  0.        , -0.5463165 ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.56161845, -0.36004869,  1.        ,  1.59019353,
         1.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.93871378, -0.10856731,  1.        , -0.97622401,
         0.        ,  0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.        , -0.10208308,  0.3315251 ,  0.        , -0.83292151,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        , -0.56140996, -1.36597421,  0.        ,  0.88670852,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ],
       [ 0.        ,  0.34416932, -0.61153007,  0.        ,  0.11808601,
         0.        ,  1.        ,  0.        ,  0.        ,  1.        ]])

Аналогично, соответствующие 100 значений целевой переменной Y (количество транзакций) выглядят следующим образом:

train_Y:

array([ 3, 11,  3,  4,  6, 11,  9,  2, 13,  4,  1, 10, 13,  8,  5,  8,  1,
       11,  1, 10,  6,  3,  8,  3,  8,  2, 20,  9, 13,  9,  3,  4,  8,  2,
        1,  3,  5,  6,  3,  1,  3,  9,  2, 14,  4,  9,  4, 13, 10,  2, 15,
        1,  2,  8,  4, 12, 14,  3, 13,  3, 11,  7, 10, 10,  6, 13,  5,  5,
       10,  9,  2,  4,  5,  3, 19,  5,  4,  2,  7,  5,  6,  1,  9,  1,  1,
        1, 12,  5, 11,  3,  5, 12,  2,  3,  3,  4,  3,  5,  2,  4])

Теперь я применяю несколько алгоритмов линейной регрессии, используя Grid Search для настройки гиперпараметров.

class TotalReservations(object):
    """Tune an estimator with a cross-validated grid search and report scores.

    The instance keeps the train/test splits, the estimator and its
    hyper-parameter grid. ``grid_search`` fits a ``GridSearchCV`` on the
    training split; ``train_performance`` prints the scores it collected.
    """

    def __init__(self,train_X_std,train_Y,test_X_std,test_Y,estimator,param_grid):
        """Store the data splits, the estimator and its parameter grid.

        Args:
            train_X_std: Training features (the ``_std`` suffix suggests they
                are already standardized -- TODO confirm with the caller).
            train_Y: Training targets.
            test_X_std: Test features (presumably standardized as well).
            test_Y: Test targets.
            estimator: The scikit-learn estimator to tune.
            param_grid: Hyper-parameter grid handed to ``GridSearchCV``.
        """
        self.train_X = train_X_std
        self.train_Y = train_Y
        self.test_X = test_X_std
        self.test_Y = test_Y
        # The hyper-parameter grid for this estimator
        self.param_grid = param_grid
        # The actual ML estimator from sklearn
        self.estimator = estimator

    def grid_search(self):
        """Run a 5-fold cross-validated grid search on the training split.

        The fitted ``GridSearchCV`` object is stored in ``self.g_cv``.
        """
        # Imported locally, as in the original snippet, so that merely
        # defining the class does not require scikit-learn to be importable.
        from sklearn.model_selection import GridSearchCV

        self.g_cv = GridSearchCV(
            estimator=self.estimator,
            param_grid=self.param_grid,
            cv=5,
            # train_performance() reads cv_results_['mean_train_score'];
            # that key is only present when train scores are requested.
            return_train_score=True,
        )
        self.g_cv.fit(self.train_X, self.train_Y)

    def train_performance(self):
        """Print the train/validation scores gathered by ``grid_search``.

        ``grid_search`` must have been called first, otherwise ``self.g_cv``
        does not exist. The mean scores (one entry per parameter combination
        in the grid) are also stored in ``self.mean_train_score`` and
        ``self.mean_val_score``.
        """
        cv_results = self.g_cv.cv_results_
        self.mean_train_score = cv_results['mean_train_score']
        # GridSearchCV names the validation-fold scores "test" scores.
        self.mean_val_score = cv_results['mean_test_score']

        # Printing the Dummy Regressor score
        # NOTE(review): dummy_regressor is not defined in the visible part of
        # this class -- presumably it lives elsewhere in the full source.
        self.dummy_regressor(self.train_X,self.train_Y,self.test_X,self.test_Y)

        print("\nThe mean train scores are {}\n".format(self.mean_train_score))

        print("The mean validation scores are {}\n".format(self.mean_val_score))

        # NOTE(review): best_score_ is the mean cross-validated score of the
        # best parameter setting, not a score on self.test_X / self.test_Y.
        print("The score on held out data is: {score}\n Hyper-Parameters for Best Score : {param}\n"\
        .format(score=self.g_cv.best_score_,param=self.g_cv.best_params_))

Сначала я запускаю гребневую регрессию (Ridge) со следующей сеткой значений параметров.

# For Ridge Regression
# Ridge estimator with random_state pinned to 42.
# NOTE(review): presumably passed to TotalReservations as ``estimator`` -- the
# construction is not shown in this snippet; confirm.
from sklearn.linear_model import Ridge
ridge_lr=Ridge(random_state=42)

#***************************************************************************************

# Hyper-parameter grid for Ridge regression: candidate values for ``alpha``,
# one order of magnitude apart. Kept at module level (no leading indent) so
# the snippet runs as a plain script.
grid_ridge = {'alpha': [0.0001, 0.001, 0.01, 0.1, 1.0]}

Это даёт мне очень плохие значения R2 и MSE. Заметьте, что я провёл 5-кратную перекрёстную валидацию, поэтому у нас есть 5 наборов оценок на обучении и на валидации (по одному на каждый фолд) и средние оценки по всем 5.

Метрика оценки по умолчанию в sklearn — это, насколько я понимаю, R2.

# NOTE(review): td_ridge is not constructed in the visible snippet --
# presumably a TotalReservations built with ridge_lr and grid_ridge; confirm.
td_ridge.grid_search()

# Training performance of the Ridge regressor
td_ridge.train_performance()


The mean train scores are [0.21945918 0.21945918 0.21945917 0.2194588  0.21942437]

The mean validation scores are [0.18215611 0.18215637 0.18215894 0.18218431 0.18240627]

The score on held out data is: 0.18240626691460574
 Hyper-Parameters for Best Score : {'alpha': 1.0}

The R2 Score of ridge_regressor on test data is: 0.17520249693363288

The mse of ridge_regressor on test data is: 17.791933297637534
The mean absolute error of ridge_regressor on test data is: 3.381908224949849

Я также пробовал регрессор на основе случайного леса (Random Forest Regressor), но результаты те же.

Есть идеи, что здесь происходит? Почему оценки так плохи?

...