np.where IndexError исключение - PullRequest
0 голосов
/ 03 мая 2018

У меня очень простой код:

import numpy as np
num_classes = 12
im_pred = np.random.randint(0, num_classes, (224, 244))
img = np.zeros((224, 224, 3))
print(im_pred.shape)
#(224, 244)
print(img.shape)
#(224, 224, 3)
for i in range(num_classes):
    img[np.where(im_pred==i), :] = [225, 0, 0]

Traceback (последний последний вызов):
Файл "", строка 2, в
IndexError: индекс 227 выходит за пределы оси 0 с размером 224

x, y = np.where(im_pred==i)
print(np.max(x), np.max(y))
#223 243

Почему я получаю IndexError? Что касается моего понимания np.where, значения возвращаемых индексов должны быть меньше 224.

Дайте мне знать. Я начинаю удивляться, если установка numpy глючит.

Спасибо.

Ответы [ 2 ]

0 голосов
/ 03 мая 2018

Проблема в том, что вы сделали img и img_pred разных размеров:

im_pred.shape == (224, 244)

, а

img.shape == (224, 224, 3)

Вторые оси имеют разные размеры.

Но как только вы это исправите, необходимо выполнить простую оптимизацию. Здесь нет необходимости в np.where. Просто используйте прямое логическое индексирование:

for i in range(num_classes):
    img[im_pred == i, 0] = 255

Примечание. Я также опускаю два ноля, так как вы инициализируете массив нулями при построении.

0 голосов
/ 03 мая 2018

Нет, Numpy не глючит. Посмотрите, как вы определили im_pred на секунду, вы рисуете случайное целое число от 0 до 11 для массива, который имеет размер 224 на 244. Поэтому причина, по которой он выдает ошибку, заключается в том, что размерность 244 слишком велика для вашего переменная img, которая только 224 на 224 на 3. Я думаю, вы могли иметь в виду, что оба имеют одно и то же первое и второе измерения, что-то вроде

img = np.zeros((224,244,3)) 
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...