Я впервые пытаюсь настроить параметры KNeighbors с помощью hyperopt, но получаю странную ошибку.Не уверен, где проблема, но надеюсь на исправление.Вот более подробная информация по этому вопросу:
Код:
from sklearn.neighbors import KNeighborsRegressor
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
space4knn = {
'n_neighbors': hp.choice('n_neighbors', range(1,50)),
'weights': hp.choice('weights',['uniform','distance']),
'algorithm': hp.choice('algorithm',['auto', 'ball_tree', 'kd_tree', 'brute']),
'leaf_size': hp.quniform('leaf_size',1,50,1),
'metric': hp.choice('metric',['minkowski','mahalanobis','chebyshev','seuclidean']),
'p': hp.quniform('p',1,15,1),
'V': hp.quniform('V',1,15,1)
}
def score(params):
print("Training with params: ")
print(params)
knn = KNeighborsRegressor(params)
knn.fit(X_train[X_train['date_block_num']<33].drop(tc+['ID','target'],axis = 1),\
X_train[X_train['date_block_num']<33]['target'])
y_pred = knn.predict(X_train[X_train['date_block_num']==33].drop(tc+['ID', 'target'],axis = 1))
y_pred = np.where( y_pred > 20, 20, np.where(y_pred < 0, 0, y_pred))
y = X_train[X_train['date_block_num']==33]['target']
error = np.sqrt(mean_squared_error(np.where( y > 20, 20, np.where(y < 0, 0, y)), np.round(y_pred)))
# TODO: Add the importance for the selected features
print("\tScore {0}\n\n".format(1-error))
return {'loss': error, 'status': STATUS_OK}
best = fmin(score, space4knn, algo=tpe.suggest,
# trials=trials,
max_evals=100)
print("The best hyperparameters are: ", "\n")
print(best)
Ошибка:
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-19-88db883141df> in <module>()
34 best = fmin(score, space4knn, algo=tpe.suggest,
35 # trials=trials,
---> 36 max_evals=100)
37
38 print("The best hyperparameters are: ", "\n")
C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in fmin(fn, space, algo, max_evals, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)
405 show_progressbar=show_progressbar)
406 rval.catch_eval_exceptions = catch_eval_exceptions
--> 407 rval.exhaust()
408 if return_argmin:
409 return trials.argmin
C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in exhaust(self)
260 def exhaust(self):
261 n_done = len(self.trials)
--> 262 self.run(self.max_evals - n_done, block_until_done=self.asynchronous)
263 self.trials.refresh()
264 return self
C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in run(self, N, block_until_done)
225 else:
226 # -- loop over trials and do the jobs directly
--> 227 self.serial_evaluate()
228
229 try:
C:\Anaconda3\lib\site-packages\hyperopt\fmin.py in serial_evaluate(self, N)
139 ctrl = base.Ctrl(self.trials, current_trial=trial)
140 try:
--> 141 result = self.domain.evaluate(spec, ctrl)
142 except Exception as e:
143 logger.info('job exception: %s' % str(e))
C:\Anaconda3\lib\site-packages\hyperopt\base.py in evaluate(self, config, ctrl, attach_attachments)
842 memo=memo,
843 print_node_on_error=self.rec_eval_print_node_on_error)
--> 844 rval = self.fn(pyll_rval)
845
846 if isinstance(rval, (float, int, np.number)):
<ipython-input-19-88db883141df> in score(params)
17 knn = KNeighborsRegressor(params)
18
---> 19 knn.fit(X_train[X_train['date_block_num']<33].drop(tc+['ID','target'],axis
= 1), X_train[X_train['date_block_num']<33]['target'])
20
21 y_pred = knn.predict(X_train[X_train['date_block_num']==33].drop(tc+['ID', 'target'],axis = 1))
~\AppData\Roaming\Python\Python36\site-packages\sklearn\neighbors\base.py in fit(self, X, y)
871 X, y = check_X_y(X, y, "csr", multi_output=True)
872 self._y = y
--> 873 return self._fit(X)
874
875
~\AppData\Roaming\Python\Python36\site-packages\sklearn\neighbors\base.py in _fit(self, X)
237 # and KDTree is generally faster when available
238 if ((self.n_neighbors is None or
--> 239 self.n_neighbors < self._fit_X.shape[0] // 2) and
240 self.metric != 'precomputed'):
241 if self.effective_metric_ in VALID_METRICS['kd_tree']:
TypeError: '<' not supported between instances of 'dict' and 'int'
Часть данных:
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table class="dataframe" border="1">
<thead>
<tr style="text-align: right;">
<th></th>
<th>shop_avg_item_price_per_category</th>
<th>shop_sum_item_cnt_day_lag_1</th>
<th>shop_avg_item_cnt_day_lag_1</th>
<th>category_avg_item_price_lag_1</th>
<th>shop_avg_item_price_per_category_lag_1</th>
<th>avg_item_cnt_day_lag_2</th>
<th>shop_sum_item_cnt_day_lag_2</th>
<th>shop_avg_item_cnt_day_lag_2</th>
<th>category_avg_item_price_lag_2</th>
<th>shop_avg_item_price_per_category_lag_2</th>
<th>shop_sum_item_cnt_per_category_lag_2</th>
<th>shop_avg_item_cnt_per_category_lag_2</th>
<th>item_price_lag_3</th>
<th>avg_item_price_lag_3</th>
<th>sum_item_cnt_day_lag_3</th>
<th>avg_item_cnt_day_lag_3</th>
<th>shop_sum_item_cnt_day_lag_3</th>
<th>shop_avg_item_cnt_day_lag_3</th>
<th>category_avg_item_price_lag_3</th>
<th>category_sum_item_cnt_day_lag_3</th>
<th>category_avg_item_cnt_day_lag_3</th>
<th>shop_avg_item_price_per_category_lag_3</th>
<th>shop_sum_item_cnt_per_category_lag_3</th>
<th>shop_avg_item_cnt_per_category_lag_3</th>
<th>item_price_lag_4</th>
<th>...</th>
<th>shop_avg_item_price_per_category_lag_4</th>
<th>shop_sum_item_cnt_per_category_lag_4</th>
<th>shop_avg_item_cnt_per_category_lag_4</th>
<th>item_price_lag_6</th>
<th>sum_item_cnt_day_lag_6</th>
<th>avg_item_cnt_day_lag_6</th>
<th>shop_sum_item_cnt_day_lag_6</th>
<th>shop_avg_item_cnt_day_lag_6</th>
<th>category_avg_item_price_lag_6</th>
<th>category_sum_item_cnt_day_lag_6</th>
<th>category_avg_item_cnt_day_lag_6</th>
<th>shop_avg_item_price_per_category_lag_6</th>
<th>shop_sum_item_cnt_per_category_lag_6</th>
<th>shop_avg_item_cnt_per_category_lag_6</th>
<th>target_lag_12</th>
<th>sum_item_cnt_day_lag_12</th>
<th>avg_item_cnt_day_lag_12</th>
<th>shop_sum_item_cnt_day_lag_12</th>
<th>shop_avg_item_cnt_day_lag_12</th>
<th>category_avg_item_price_lag_12</th>
<th>category_sum_item_cnt_day_lag_12</th>
<th>category_avg_item_cnt_day_lag_12</th>
<th>shop_avg_item_price_per_category_lag_12</th>
<th>shop_sum_item_cnt_per_category_lag_12</th>
<th>shop_avg_item_cnt_per_category_lag_12</th>
</tr>
</thead>
<tbody>
<tr>
<th>4488710</th>
<td>1255.026825</td>
<td>1322.0</td>
<td>0.156025</td>
<td>1327.493573</td>
<td>1326.642773</td>
<td>0.044444</td>
<td>862.0</td>
<td>0.106564</td>
<td>1394.281350</td>
<td>1393.509145</td>
<td>140.0</td>
<td>0.571429</td>
<td>1461.228571</td>
<td>1393.537888</td>
<td>6.0</td>
<td>0.130435</td>
<td>795.0</td>
<td>0.098893</td>
<td>1354.019638</td>
<td>14113.0</td>
<td>1.203154</td>
<td>1350.446936</td>
<td>191.0</td>
<td>0.749020</td>
<td>1461.228571</td>
<td>...</td>
<td>1252.440779</td>
<td>257.0</td>
<td>0.992278</td>
<td>1461.228571</td>
<td>3.0</td>
<td>0.065217</td>
<td>807.0</td>
<td>0.096014</td>
<td>1236.387840</td>
<td>7497.0</td>
<td>0.681917</td>
<td>1235.285393</td>
<td>138.0</td>
<td>0.577406</td>
<td>1.0</td>
<td>7.0</td>
<td>0.155556</td>
<td>1146.0</td>
<td>0.14122</td>
<td>1201.195942</td>
<td>8983.0</td>
<td>0.849456</td>
<td>1202.996944</td>
<td>114.0</td>
<td>0.485106</td>
</tr>
<tr>
<th>4488711</th>
<td>201.385813</td>
<td>1322.0</td>
<td>0.156025</td>
<td>202.264922</td>
<td>202.541912</td>
<td>1.022222</td>
<td>862.0</td>
<td>0.106564</td>
<td>196.136342</td>
<td>195.771183</td>
<td>49.0</td>
<td>0.022940</td>
<td>262.545138</td>
<td>234.362681</td>
<td>24.0</td>
<td>0.521739</td>
<td>795.0</td>
<td>0.098893</td>
<td>199.127951</td>
<td>24173.0</td>
<td>0.254233</td>
<td>198.815861</td>
<td>44.0</td>
<td>0.021287</td>
<td>262.545138</td>
<td>...</td>
<td>195.339467</td>
<td>56.0</td>
<td>0.025998</td>
<td>262.545138</td>
<td>41.0</td>
<td>0.891304</td>
<td>807.0</td>
<td>0.096014</td>
<td>191.044051</td>
<td>24806.0</td>
<td>0.228500</td>
<td>190.384669</td>
<td>92.0</td>
<td>0.038983</td>
<td>0.0</td>
<td>0.0</td>
<td>0.000000</td>
<td>0.0</td>
<td>0.00000</td>
<td>199.562656</td>
<td>0.0</td>
<td>0.000000</td>
<td>199.726151</td>
<td>0.0</td>
<td>0.000000</td>
</tr>
<tr>
<th>4488712</th>
<td>356.645697</td>
<td>1322.0</td>
<td>0.156025</td>
<td>356.371940</td>
<td>354.899145</td>
<td>0.600000</td>
<td>862.0</td>
<td>0.106564</td>
<td>341.460339</td>
<td>338.698243</td>
<td>42.0</td>
<td>0.046823</td>
<td>518.387881</td>
<td>530.473848</td>
<td>25.0</td>
<td>0.543478</td>
<td>795.0</td>
<td>0.098893</td>
<td>349.472332</td>
<td>6950.0</td>
<td>0.171495</td>
<td>348.349523</td>
<td>29.0</td>
<td>0.032917</td>
<td>518.387881</td>
<td>...</td>
<td>341.946769</td>
<td>24.0</td>
<td>0.028302</td>
<td>518.387881</td>
<td>14.0</td>
<td>0.304348</td>
<td>807.0</td>
<td>0.096014</td>
<td>336.559242</td>
<td>8623.0</td>
<td>0.224768</td>
<td>333.367923</td>
<td>21.0</td>
<td>0.025180</td>
<td>0.0</td>
<td>0.0</td>
<td>0.000000</td>
<td>0.0</td>
<td>0.00000</td>
<td>352.298165</td>
<td>0.0</td>
<td>0.000000</td>
<td>352.452030</td>
<td>0.0</td>
<td>0.000000</td>
</tr>
<tr>
<th>4488713</th>
<td>201.385813</td>
<td>1322.0</td>
<td>0.156025</td>
<td>202.264922</td>
<td>202.541912</td>
<td>1.800000</td>
<td>862.0</td>
<td>0.106564</td>
<td>196.136342</td>
<td>195.771183</td>
<td>49.0</td>
<td>0.022940</td>
<td>221.330209</td>
<td>205.940671</td>
<td>58.0</td>
<td>1.260870</td>
<td>795.0</td>
<td>0.098893</td>
<td>199.127951</td>
<td>24173.0</td>
<td>0.254233</td>
<td>198.815861</td>
<td>44.0</td>
<td>0.021287</td>
<td>221.330209</td>
<td>...</td>
<td>195.339467</td>
<td>56.0</td>
<td>0.025998</td>
<td>221.330209</td>
<td>87.0</td>
<td>1.891304</td>
<td>807.0</td>
<td>0.096014</td>
<td>191.044051</td>
<td>24806.0</td>
<td>0.228500</td>
<td>190.384669</td>
<td>92.0</td>
<td>0.038983</td>
<td>0.0</td>
<td>299.0</td>
<td>6.644444</td>
<td>1146.0</td>
<td>0.14122</td>
<td>179.841439</td>
<td>33489.0</td>
<td>0.310860</td>
<td>178.901545</td>
<td>175.0</td>
<td>0.073099</td>
</tr>
<tr>
<th>4488714</th>
<td>356.645697</td>
<td>1322.0</td>
<td>0.156025</td>
<td>356.371940</td>
<td>354.899145</td>
<td>0.333333</td>
<td>862.0</td>
<td>0.106564</td>
<td>341.460339</td>
<td>338.698243</td>
<td>42.0</td>
<td>0.046823</td>
<td>245.228381</td>
<td>218.896182</td>
<td>33.0</td>
<td>0.717391</td>
<td>795.0</td>
<td>0.098893</td>
<td>349.472332</td>
<td>6950.0</td>
<td>0.171495</td>
<td>348.349523</td>
<td>29.0</td>
<td>0.032917</td>
<td>245.228381</td>
<td>...</td>
<td>341.946769</td>
<td>24.0</td>
<td>0.028302</td>
<td>245.228381</td>
<td>46.0</td>
<td>1.000000</td>
<td>807.0</td>
<td>0.096014</td>
<td>336.559242</td>
<td>8623.0</td>
<td>0.224768</td>
<td>333.367923</td>
<td>21.0</td>
<td>0.025180</td>
<td>1.0</td>
<td>61.0</td>
<td>1.355556</td>
<td>1146.0</td>
<td>0.14122</td>
<td>338.532581</td>
<td>6094.0</td>
<td>0.199738</td>
<td>333.080366</td>
<td>52.0</td>
<td>0.076696</td>
</tr>
</tbody>
</table>
<p>5 rows × 58 columns</p>
</div>
Все типы данных являются float32.Это ошибка от KNeignoborsRegressor или от data или hyperopt и как ее исправить?Спасибо.