Ошибка на входе KDTree ValueError: установка элемента массива с последовательностью - PullRequest
0 голосов
/ 16 июня 2020

Я пытаюсь использовать KDTree для классификации некоторых входных данных в следующем коде:

tree = KDTree(train_x, leaf_size=2)
dist, ind = tree.query(train_x[:1], k=3)
print(ind)
print(dist)

Это ошибка, которую я получаю

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: only size-1 arrays can be converted to Python scalars

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-39-3faca41d63b2> in <module>()
----> 1 tree = KDTree(train_x, leaf_size=2)              # doctest: +SKIP
      2 dist, ind = tree.query(train_x[:1], k=3)                # doctest: +SKIP
      3 print(ind)  # indices of 3 closest neighbors
      4 print(dist)  # distances to 3 closest neighbors

sklearn/neighbors/_binary_tree.pxi in sklearn.neighbors._kd_tree.BinaryTree.__init__()

/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

ValueError: setting an array element with a sequence.

Вот как выглядит мой train_x например,

array([array([-0.0473686 ,  0.21476793, -0.24364506, ...,  0.25211015,
        0.09572177,  0.03847923], dtype=float32),
       array([ 0.00438724,  0.17182858, -0.27784246, ...,  0.18208861,
       -0.24425837, -0.1276848 ], dtype=float32),
       array([-0.07410974, -0.10420805, -0.20287056, ...,  0.21935067,
        0.04020798,  0.04056953], dtype=float32),
       array([ 0.00465234,  0.08095299, -0.31916058, ...,  0.11371485,
        0.05105788, -0.00181724], dtype=float32),
       array([-0.14058642,  0.0477878 , -0.18549545, ...,  0.23005083,
       -0.03298192,  0.07319082], dtype=float32),
       array([ 0.08589569, -0.13067536,  0.0169823 , ...,  0.25152165,
        0.1758727 , -0.13552606], dtype=float32)], dtype=object)

Может ли кто-нибудь сказать мне, почему я получаю эту ошибку?

...