Ошибка при попытке настройки ближайших соседей с помощью hyperopt в python - PullRequest
0 голосов
/ 13 мая 2019

Я впервые пытаюсь настроить параметры KNeighbors с помощью hyperopt, но получаю странную ошибку.Не уверен, где проблема, но надеюсь на исправление.Вот более подробная информация по этому вопросу:

Код:

from sklearn.neighbors import KNeighborsRegressor
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

space4knn = {
    'n_neighbors': hp.choice('n_neighbors', range(1,50)),
    'weights': hp.choice('weights',['uniform','distance']),
    'algorithm': hp.choice('algorithm',['auto', 'ball_tree', 'kd_tree', 'brute']),
    'leaf_size': hp.quniform('leaf_size',1,50,1),
    'metric': hp.choice('metric',['minkowski','mahalanobis','chebyshev','seuclidean']),
    'p': hp.quniform('p',1,15,1),
    'V': hp.quniform('V',1,15,1)
}

def score(params):
    print("Training with params: ")
    print(params)

    knn = KNeighborsRegressor(params)

    knn.fit(X_train[X_train['date_block_num']<33].drop(tc+['ID','target'],axis = 1),\
            X_train[X_train['date_block_num']<33]['target'])

    y_pred = knn.predict(X_train[X_train['date_block_num']==33].drop(tc+['ID', 'target'],axis = 1))

    y_pred = np.where( y_pred > 20, 20, np.where(y_pred < 0, 0, y_pred))

    y = X_train[X_train['date_block_num']==33]['target']

    error = np.sqrt(mean_squared_error(np.where( y > 20, 20, np.where(y < 0, 0, y)), np.round(y_pred)))

    # TODO: Add the importance for the selected features
    print("\tScore {0}\n\n".format(1-error))

    return {'loss': error, 'status': STATUS_OK}

best = fmin(score, space4knn, algo=tpe.suggest, 
                # trials=trials, 
                max_evals=100)

print("The best hyperparameters are: ", "\n")
print(best)

Ошибка:

--------------------------------------------------------------------------- TypeError                                 Traceback (most recent call last) <ipython-input-19-88db883141df> in <module>()
     34 best = fmin(score, space4knn, algo=tpe.suggest, 
     35                 # trials=trials,
---> 36                 max_evals=100)
     37 
     38 print("The best hyperparameters are: ", "\n")

C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in fmin(fn, space, algo, max_evals, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)
    405                     show_progressbar=show_progressbar)
    406     rval.catch_eval_exceptions = catch_eval_exceptions
--> 407     rval.exhaust()
    408     if return_argmin:
    409         return trials.argmin

C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in exhaust(self)
    260     def exhaust(self):
    261         n_done = len(self.trials)
--> 262         self.run(self.max_evals - n_done, block_until_done=self.asynchronous)
    263         self.trials.refresh()
    264         return self

C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in run(self, N, block_until_done)
    225                     else:
    226                         # -- loop over trials and do the jobs directly
--> 227                         self.serial_evaluate()
    228 
    229                     try:

C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in serial_evaluate(self, N)
    139                 ctrl = base.Ctrl(self.trials, current_trial=trial)
    140                 try:
--> 141                     result = self.domain.evaluate(spec, ctrl)
    142                 except Exception as e:
    143                     logger.info('job exception: %s' % str(e))

C:\Anaconda3\lib\site-packages\hyperopt\base.py in evaluate(self, config, ctrl, attach_attachments)
    842                 memo=memo,
    843                 print_node_on_error=self.rec_eval_print_node_on_error)
--> 844             rval = self.fn(pyll_rval)
    845 
    846         if isinstance(rval, (float, int, np.number)):

<ipython-input-19-88db883141df> in score(params)
     17     knn = KNeighborsRegressor(params)
     18 
---> 19     knn.fit(X_train[X_train['date_block_num']<33].drop(tc+['ID','target'],axis
= 1),            X_train[X_train['date_block_num']<33]['target'])
     20 
     21     y_pred = knn.predict(X_train[X_train['date_block_num']==33].drop(tc+['ID', 'target'],axis = 1))

~\AppData\Roaming\Python\Python36\site-packages\sklearn\neighbors\base.py in fit(self, X, y)
    871             X, y = check_X_y(X, y, "csr", multi_output=True)
    872         self._y = y
--> 873         return self._fit(X)
    874 
    875 

~\AppData\Roaming\Python\Python36\site-packages\sklearn\neighbors\base.py in _fit(self, X)
    237             # and KDTree is generally faster when available
    238             if ((self.n_neighbors is None or
--> 239                  self.n_neighbors < self._fit_X.shape[0] // 2) and
    240                     self.metric != 'precomputed'):
    241                 if self.effective_metric_ in VALID_METRICS['kd_tree']:

TypeError: '<' not supported between instances of 'dict' and 'int'

Часть данных:

<div>
<style scoped="">
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table class="dataframe" border="1">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>shop_avg_item_price_per_category</th>
      <th>shop_sum_item_cnt_day_lag_1</th>
      <th>shop_avg_item_cnt_day_lag_1</th>
      <th>category_avg_item_price_lag_1</th>
      <th>shop_avg_item_price_per_category_lag_1</th>
      <th>avg_item_cnt_day_lag_2</th>
      <th>shop_sum_item_cnt_day_lag_2</th>
      <th>shop_avg_item_cnt_day_lag_2</th>
      <th>category_avg_item_price_lag_2</th>
      <th>shop_avg_item_price_per_category_lag_2</th>
      <th>shop_sum_item_cnt_per_category_lag_2</th>
      <th>shop_avg_item_cnt_per_category_lag_2</th>
      <th>item_price_lag_3</th>
      <th>avg_item_price_lag_3</th>
      <th>sum_item_cnt_day_lag_3</th>
      <th>avg_item_cnt_day_lag_3</th>
      <th>shop_sum_item_cnt_day_lag_3</th>
      <th>shop_avg_item_cnt_day_lag_3</th>
      <th>category_avg_item_price_lag_3</th>
      <th>category_sum_item_cnt_day_lag_3</th>
      <th>category_avg_item_cnt_day_lag_3</th>
      <th>shop_avg_item_price_per_category_lag_3</th>
      <th>shop_sum_item_cnt_per_category_lag_3</th>
      <th>shop_avg_item_cnt_per_category_lag_3</th>
      <th>item_price_lag_4</th>
      <th>...</th>
      <th>shop_avg_item_price_per_category_lag_4</th>
      <th>shop_sum_item_cnt_per_category_lag_4</th>
      <th>shop_avg_item_cnt_per_category_lag_4</th>
      <th>item_price_lag_6</th>
      <th>sum_item_cnt_day_lag_6</th>
      <th>avg_item_cnt_day_lag_6</th>
      <th>shop_sum_item_cnt_day_lag_6</th>
      <th>shop_avg_item_cnt_day_lag_6</th>
      <th>category_avg_item_price_lag_6</th>
      <th>category_sum_item_cnt_day_lag_6</th>
      <th>category_avg_item_cnt_day_lag_6</th>
      <th>shop_avg_item_price_per_category_lag_6</th>
      <th>shop_sum_item_cnt_per_category_lag_6</th>
      <th>shop_avg_item_cnt_per_category_lag_6</th>
      <th>target_lag_12</th>
      <th>sum_item_cnt_day_lag_12</th>
      <th>avg_item_cnt_day_lag_12</th>
      <th>shop_sum_item_cnt_day_lag_12</th>
      <th>shop_avg_item_cnt_day_lag_12</th>
      <th>category_avg_item_price_lag_12</th>
      <th>category_sum_item_cnt_day_lag_12</th>
      <th>category_avg_item_cnt_day_lag_12</th>
      <th>shop_avg_item_price_per_category_lag_12</th>
      <th>shop_sum_item_cnt_per_category_lag_12</th>
      <th>shop_avg_item_cnt_per_category_lag_12</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>4488710</th>
      <td>1255.026825</td>
      <td>1322.0</td>
      <td>0.156025</td>
      <td>1327.493573</td>
      <td>1326.642773</td>
      <td>0.044444</td>
      <td>862.0</td>
      <td>0.106564</td>
      <td>1394.281350</td>
      <td>1393.509145</td>
      <td>140.0</td>
      <td>0.571429</td>
      <td>1461.228571</td>
      <td>1393.537888</td>
      <td>6.0</td>
      <td>0.130435</td>
      <td>795.0</td>
      <td>0.098893</td>
      <td>1354.019638</td>
      <td>14113.0</td>
      <td>1.203154</td>
      <td>1350.446936</td>
      <td>191.0</td>
      <td>0.749020</td>
      <td>1461.228571</td>
      <td>...</td>
      <td>1252.440779</td>
      <td>257.0</td>
      <td>0.992278</td>
      <td>1461.228571</td>
      <td>3.0</td>
      <td>0.065217</td>
      <td>807.0</td>
      <td>0.096014</td>
      <td>1236.387840</td>
      <td>7497.0</td>
      <td>0.681917</td>
      <td>1235.285393</td>
      <td>138.0</td>
      <td>0.577406</td>
      <td>1.0</td>
      <td>7.0</td>
      <td>0.155556</td>
      <td>1146.0</td>
      <td>0.14122</td>
      <td>1201.195942</td>
      <td>8983.0</td>
      <td>0.849456</td>
      <td>1202.996944</td>
      <td>114.0</td>
      <td>0.485106</td>
    </tr>
    <tr>
      <th>4488711</th>
      <td>201.385813</td>
      <td>1322.0</td>
      <td>0.156025</td>
      <td>202.264922</td>
      <td>202.541912</td>
      <td>1.022222</td>
      <td>862.0</td>
      <td>0.106564</td>
      <td>196.136342</td>
      <td>195.771183</td>
      <td>49.0</td>
      <td>0.022940</td>
      <td>262.545138</td>
      <td>234.362681</td>
      <td>24.0</td>
      <td>0.521739</td>
      <td>795.0</td>
      <td>0.098893</td>
      <td>199.127951</td>
      <td>24173.0</td>
      <td>0.254233</td>
      <td>198.815861</td>
      <td>44.0</td>
      <td>0.021287</td>
      <td>262.545138</td>
      <td>...</td>
      <td>195.339467</td>
      <td>56.0</td>
      <td>0.025998</td>
      <td>262.545138</td>
      <td>41.0</td>
      <td>0.891304</td>
      <td>807.0</td>
      <td>0.096014</td>
      <td>191.044051</td>
      <td>24806.0</td>
      <td>0.228500</td>
      <td>190.384669</td>
      <td>92.0</td>
      <td>0.038983</td>
      <td>0.0</td>
      <td>0.0</td>
      <td>0.000000</td>
      <td>0.0</td>
      <td>0.00000</td>
      <td>199.562656</td>
      <td>0.0</td>
      <td>0.000000</td>
      <td>199.726151</td>
      <td>0.0</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>4488712</th>
      <td>356.645697</td>
      <td>1322.0</td>
      <td>0.156025</td>
      <td>356.371940</td>
      <td>354.899145</td>
      <td>0.600000</td>
      <td>862.0</td>
      <td>0.106564</td>
      <td>341.460339</td>
      <td>338.698243</td>
      <td>42.0</td>
      <td>0.046823</td>
      <td>518.387881</td>
      <td>530.473848</td>
      <td>25.0</td>
      <td>0.543478</td>
      <td>795.0</td>
      <td>0.098893</td>
      <td>349.472332</td>
      <td>6950.0</td>
      <td>0.171495</td>
      <td>348.349523</td>
      <td>29.0</td>
      <td>0.032917</td>
      <td>518.387881</td>
      <td>...</td>
      <td>341.946769</td>
      <td>24.0</td>
      <td>0.028302</td>
      <td>518.387881</td>
      <td>14.0</td>
      <td>0.304348</td>
      <td>807.0</td>
      <td>0.096014</td>
      <td>336.559242</td>
      <td>8623.0</td>
      <td>0.224768</td>
      <td>333.367923</td>
      <td>21.0</td>
      <td>0.025180</td>
      <td>0.0</td>
      <td>0.0</td>
      <td>0.000000</td>
      <td>0.0</td>
      <td>0.00000</td>
      <td>352.298165</td>
      <td>0.0</td>
      <td>0.000000</td>
      <td>352.452030</td>
      <td>0.0</td>
      <td>0.000000</td>
    </tr>
    <tr>
      <th>4488713</th>
      <td>201.385813</td>
      <td>1322.0</td>
      <td>0.156025</td>
      <td>202.264922</td>
      <td>202.541912</td>
      <td>1.800000</td>
      <td>862.0</td>
      <td>0.106564</td>
      <td>196.136342</td>
      <td>195.771183</td>
      <td>49.0</td>
      <td>0.022940</td>
      <td>221.330209</td>
      <td>205.940671</td>
      <td>58.0</td>
      <td>1.260870</td>
      <td>795.0</td>
      <td>0.098893</td>
      <td>199.127951</td>
      <td>24173.0</td>
      <td>0.254233</td>
      <td>198.815861</td>
      <td>44.0</td>
      <td>0.021287</td>
      <td>221.330209</td>
      <td>...</td>
      <td>195.339467</td>
      <td>56.0</td>
      <td>0.025998</td>
      <td>221.330209</td>
      <td>87.0</td>
      <td>1.891304</td>
      <td>807.0</td>
      <td>0.096014</td>
      <td>191.044051</td>
      <td>24806.0</td>
      <td>0.228500</td>
      <td>190.384669</td>
      <td>92.0</td>
      <td>0.038983</td>
      <td>0.0</td>
      <td>299.0</td>
      <td>6.644444</td>
      <td>1146.0</td>
      <td>0.14122</td>
      <td>179.841439</td>
      <td>33489.0</td>
      <td>0.310860</td>
      <td>178.901545</td>
      <td>175.0</td>
      <td>0.073099</td>
    </tr>
    <tr>
      <th>4488714</th>
      <td>356.645697</td>
      <td>1322.0</td>
      <td>0.156025</td>
      <td>356.371940</td>
      <td>354.899145</td>
      <td>0.333333</td>
      <td>862.0</td>
      <td>0.106564</td>
      <td>341.460339</td>
      <td>338.698243</td>
      <td>42.0</td>
      <td>0.046823</td>
      <td>245.228381</td>
      <td>218.896182</td>
      <td>33.0</td>
      <td>0.717391</td>
      <td>795.0</td>
      <td>0.098893</td>
      <td>349.472332</td>
      <td>6950.0</td>
      <td>0.171495</td>
      <td>348.349523</td>
      <td>29.0</td>
      <td>0.032917</td>
      <td>245.228381</td>
      <td>...</td>
      <td>341.946769</td>
      <td>24.0</td>
      <td>0.028302</td>
      <td>245.228381</td>
      <td>46.0</td>
      <td>1.000000</td>
      <td>807.0</td>
      <td>0.096014</td>
      <td>336.559242</td>
      <td>8623.0</td>
      <td>0.224768</td>
      <td>333.367923</td>
      <td>21.0</td>
      <td>0.025180</td>
      <td>1.0</td>
      <td>61.0</td>
      <td>1.355556</td>
      <td>1146.0</td>
      <td>0.14122</td>
      <td>338.532581</td>
      <td>6094.0</td>
      <td>0.199738</td>
      <td>333.080366</td>
      <td>52.0</td>
      <td>0.076696</td>
    </tr>
  </tbody>
</table>
<p>5 rows × 58 columns</p>
</div>

Все типы данных являются float32.Это ошибка от KNeignoborsRegressor или от data или hyperopt и как ее исправить?Спасибо.

...