Доступ к индексу в numpy - PullRequest
0 голосов
/ 16 апреля 2020

Рассмотрим следующий код:

import numpy as np 
err = abs(np.subtract(unNormalizedTestPredictions, test_labels))
print("max error", max(err) )
tmp = np.where( err > 300000) #tmp is a ndarray
print("large values located at ",tmp) 

Вывод:

макс. Ошибка 334901.5078125 больших значений, расположенных в (массив ([64828]),)

Как я могу посмотреть на эти большие значения (err> 300000) и, что еще лучше, найти значение в test_labels, которое его вызвало? Я полагаю, что np.where говорит мне, что он находится на ndx 64828, но следующий код взрывается. Может ли этот же индекс использоваться в test_labels?

newarr = err[np.array([64828])]

1 Ответ

0 голосов
/ 16 апреля 2020

Вы можете попробовать что-то вроде

new_array = test_labels[err> 300000]

Чтобы получить все значения.

Обратите внимание, что err> 300000 - это логический массив, который будет true, где условие проверяется. Эти индексы должны соответствовать test_labels

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...