Поиск закодированной метки в ndarray - PullRequest
0 голосов
/ 06 апреля 2020

У меня есть ndarray под названием labels с формой (6000, 8). Это 6000 массивов с горячим кодированием в 8 категориях. Я хочу найти ярлыки, которые выглядят так:

[1,0,0,0,0,0,0,0]

, а затем попытался сделать так

np.where(labels==[1,0,0,0,0,0,0,0,0])

, но это не дает ожидаемого результата

1 Ответ

2 голосов
/ 06 апреля 2020

Вам нужно all вдоль второй оси:

np.where((labels == [1,0,0,0,0,0,0,0]).all(1))

См. Этот меньший пример:

labels = np.array([[1,0,0,1,0,0,0,0], 
                   [0,0,0,0,0,1,1,0], 
                   [1,0,0,0,0,0,0,0], 
                   [0,0,0,0,0,0,0,1]])

(labels == [1,0,0,0,0,0,0,0])

array([[ True,  True,  True, False,  True,  True,  True,  True],
       [False,  True,  True,  True,  True, False, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False,  True,  True,  True,  True,  True,  True, False]])

Обратите внимание, что выше Сравнение просто возвращает массив той же формы, что и labels, поскольку сравнение имело место вдоль строк labels. Вам нужно объединиться с all, чтобы проверить, все ли элементы в строке True:

(labels == [1,0,0,0,0,0,0,0]).all(1)
 #array([False, False,  True, False])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...