У меня проблема с UMAP(...).transform()
.Я загрузил модель с trans = pickle.load()
.
ipdb> trans
UMAP(a=None, angular_rp_forest=False, b=None, init='spectral',
learning_rate=1.0, local_connectivity=1.0, metric='correlation',
metric_kwds=None, min_dist=0.5, n_components=3, n_epochs=None,
n_neighbors=3, negative_sample_rate=5, random_state=None,
repulsion_strength=1.0, set_op_mix_ratio=1.0, spread=1.0,
target_metric='categorical', target_metric_kwds=None,
target_n_neighbors=-1, target_weight=0.5, transform_queue_size=4.0,
transform_seed=42, verbose=False)
Поскольку UMAP использует пакет numba
, у меня возникла проблема с разрешением типов.Здесь показана ошибка.
ipdb> trans.transform(X_test.reshape(X_test.shape[0], -1))
*** numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of type(CPUDispatcher(<function rdist at 0x7f36c7ce4e18>)) with parameters (array(float32, 1d, C), array(float64, 1d, C))
Known signatures:
* (array(float32, 1d, A), array(float32, 1d, A)) -> float32
* parameterized
[1] During: resolving callee type: type(CPUDispatcher(<function rdist at 0x7f36c7ce4e18>))
[2] During: typing of call at /home/infinity/anaconda3/envs/HFT/lib/python3.6/site-packages/umap/umap_.py (776)
File "../../../anaconda3/envs/HFT/lib/python3.6/site-packages/umap/umap_.py", line 776:
def optimize_layout(
<source elided>
dist_squared = rdist(current, other)
^
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html
For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new
Вот пример X_test
:
array([[[-7.84867750e-05, -3.92410776e-05, -1.17713994e-04, ...,
0.00000000e+00, 0.00000000e+00, -5.88910810e-05],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
3.22361913e+01, 3.36224857e+01, 0.00000000e+00],
[-1.17741714e-04, -1.56979711e-04, 0.00000000e+00, ...,
-3.22361913e+01, -3.36224857e+01, -2.35599012e-04],
...,
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
0.00000000e+00, 0.00000000e+00, 3.92395378e-05],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
[[ 0.00000000e+00, 0.00000000e+00, 7.98615775e-04, ...,
0.00000000e+00, 0.00000000e+00, -2.66778461e-04],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 1.10377388e-03, 1.86421709e-03, 3.04066897e-04, ...,
0.00000000e+00, 0.00000000e+00, 5.52533571e-04],
...,
[ 0.00000000e+00, 1.14074947e-04, 3.80220908e-05, ...,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[-1.36996749e-03, -4.18338442e-04, -4.18322533e-04, ...,
0.00000000e+00, 0.00000000e+00, -1.21918711e-03],
[-3.80814562e-05, -1.10373187e-03, -3.80380000e-05, ...,
0.00000000e+00, 0.00000000e+00, -1.90615976e-05]]])
X.shape is (5839, 45, 41)
Как я могу исправить эту ошибку типа?
ОБНОВЛЕНИЕ
Я уведомил, что проблема с
@numba.njit("f4(f4[:],f4[:])", fastmath=True)
def rdist(x, y):
"""Reduced Euclidean distance.
Parameters
----------
x: array of shape (embedding_dim,)
y: array of shape (embedding_dim,)
Returns
-------
The squared euclidean distance between x and y
"""
result = 0.0
for i in range(x.shape[0]):
result += (x[i] - y[i]) ** 2
return result
x = [-4.3200183, 0.2443985, 0.24419938]
y = [-3.7033827, 0.91038215, 4.292648 ]
test = rdist(x,y)
Однако я не квалифицирован, чтобы решить проблему.
Если яЗапустив приведенный выше код, я получил следующую ошибку TypeError: No matching definition for argument type(s) reflected list(float64), reflected list(float64)