Как получить метки, используемые в маске сегментации в AWS Sagemaker - PullRequest
7 голосов
/ 28 мая 2020

Из маски сегментации я пытаюсь извлечь, какие метки представлены в маске.

Это изображение, которое я просматриваю в модели сегментации semanti c в AWS Sagemaker.

Motorbike and everything else background

Код для создание прогнозов и отображение маски.

from sagemaker.predictor import json_serializer, json_deserializer, RealTimePredictor
from sagemaker.content_types import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON

%%time
ss_predict = sagemaker.RealTimePredictor(endpoint=ss_model.endpoint_name, 
                                     sagemaker_session=sess,
                                    content_type = 'image/jpeg',
                                    accept = 'image/png')

return_img = ss_predict.predict(img)

from PIL import Image
import numpy as np
import io

num_labels = 21
mask = np.array(Image.open(io.BytesIO(return_img)))
plt.imshow(mask, vmin=0, vmax=num_labels-1, cmap='jet')
plt.show()

Это изображение представляет собой созданную маску сегментации, которая представляет мотоцикл, а все остальное является фоном.

[Segmented mask[2]

Как видно из кода, существует 21 возможная метка, и в маске использовались 2: одна для мотоцикла, а другая для фона. Сейчас я хотел бы выяснить, как распечатать, какие этикетки фактически использовались в этой маске из 21 возможных вариантов?

1 Ответ

1 голос
/ 17 июня 2020

Где-то у вас должно быть отображение целых чисел меток в классы меток, например

label_map = {0: 'background', 1: 'motorbike', 2: 'train', ...}

Если вы используете набор данных Pascal VO C, это будет (1 = самолет, 2 = велосипед, 3 = птица, 4 = лодка, 5 = bottle, 6 = автобус, 7 = машина, 8 = кошка, 9 = стул, 10 = корова, 11 = обеденный стол, 12 = собака, 13 = лошадь, 14 = мотоцикл, 15 = человек, 16 = растение в горшке, 17 = овца, 18 = диван, 19 = поезд, 20 = телевизор / монитор) - см. здесь: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/segexamples/index.html

Тогда вы можно просто использовать эту карту:

used_classes = np.unique(mask)
for cls in used_classes:
    print("Found class: {}".format(label_map[cls]))
...