Отображение прогнозируемых выходных данных сегментации в Pyplot - PullRequest
0 голосов
/ 28 сентября 2019

Я пытаюсь построить массив выходных данных из сети.Массив содержит данные mnist, которые были в горячем виде закодированы для сегментации изображения.

его форма: (28,28,11)

Таким образом, для каждого места в исходном изображении, где значение пикселя равно 0, горячее кодирование поместит массив [0 0 0 0 0 0 0 0 0 0 1], указывая, что этот пиксель пуст.

Если, с другой стороны, значение пикселя> 0, массив с горячей точкой будет отображать классификацию всего изображения.

EX: Если изображение mnist имеет значение 2, каждый пиксель, значение которого> 0, будет преобразован в массив [0 0 1 0 0 0 0 0 0 0 0].

Мне интересно, есть ли способ отобразить такой массив, каждый элемент которого состоит из одного горячего массива.

Я пытался просто использовать plt.imshow для данных, однако он выдает ошибку, говорящую, что «Ошибка типа: Неверные измерения для данных изображения»

Вот код I 'я работаю с

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
import skimage.transform
import cv2
import sys
from keras import Input
from keras import backend as K
from keras.utils import np_utils
from keras.models import Sequential, Model 
from keras.utils import to_categorical
from keras.losses import categorical_crossentropy
from keras.optimizers import adam
from keras.layers import Conv2D, Dense, MaxPooling2D, Flatten, Dropout, GlobalAveragePooling2D
from keras.datasets import cifar10
from keras.datasets import mnist
from keras.utils import np_utils
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(threshold=sys.maxsize)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

y_train = y_train[:10]

data = np.random.choice(255, (10,128,128))


## do you calculation of brightness here
## and expand it to one row per pixel
arr = data.reshape(-1,1)/255
## repeat labels to match the expanded pixel
labels = y_train.repeat(128*128).reshape(-1,1)

ind_row = np.arange(len(arr))
ind_col = np.where(arr>0, labels, 10).ravel()

one_hot_coded_arr = np.zeros((len(arr), 11))
one_hot_coded_arr[ind_row,ind_col]=1

## convert back to desired shape
one_hot_coded_arr = one_hot_coded_arr.reshape(-1, 128,128,11)
#print(one_hot_coded_arr[:28,:])
print(one_hot_coded_arr.shape)


plt.imshow(one_hot_coded_arr, interpolation='nearest')
plt.axis("off")
plt.show()

Я хочу отобразить изображение примерно так: https://documentation.sas.com/api/docsets/casdlpg/8.4/content/images/mnistout2.png

Но я продолжаю сталкиваться с ошибкой «Ошибка типа: Неверные размеры для данных изображения»

Любая помощь будет потрясающей, спасибо!

1 Ответ

0 голосов
/ 28 сентября 2019

У вас слишком много измерений.matplotlib.pyplot только графики 2 измерений (x, y).

Поэтому сначала вы должны выбрать свое изображение для отображения, т.е. output[n].Затем, поскольку он горячо закодирован, используйте функцию np.argmax(output[n], axis=-1) для «unencode».

В качестве альтернативы, просто выберите нужный вам слой output[n,:,:,l].

Дайте мне знать, если это работает.

...