Pytorch: Как построить прогнозный результат задачи сегментации, если размер пакета больше 1? - PullRequest
0 голосов
/ 03 мая 2019

У меня есть сомнение и вопрос о выводе сюжета разных партий в теме сегментации.

Ниже приведен фрагмент вероятности каждого класса и результат прогноза.

Я уверен, что на графике проб строится одна партия, но не уверен насчет прогноза, когда я получил torch.argmax (выводы, 1). Я составил график argmax для одной партии, в то время как выходные данные сети имеют размер [10,4,256,256].

Кроме того, мне интересно, как я могу построить прогноз для всех партий, пока размер моей партии равен 10.

outputs = model(t_image)

fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=1, ncols=5, sharex=True, sharey=True, figsize=(6,6))

img1 = ax1.imshow(torch.exp(outputs[0,0,:,:]).detach().cpu(), cmap = 'jet')
ax1.set_title("prob class 0")

img2 = ax2.imshow(torch.exp(outputs[0,1,:,:]).detach().cpu(), cmap = 'jet')
ax2.set_title("prob class 1")

img3 = ax3.imshow(torch.exp(outputs[0,2,:,:]).detach().cpu(), cmap = 'jet')
ax3.set_title("prob class 2")

img4 = ax4.imshow(torch.exp(outputs[0,3,:,:]).detach().cpu(), cmap = 'jet')
ax4.set_title("prob class 3")

img5 = ax5.imshow(torch.argmax(outputs, 1).detach().cpu().squeeze(), cmap = 'jet')
ax5.set_title("predicted")

1 Ответ

1 голос
/ 04 мая 2019

Не уверен насчет того, что вы спрашиваете.Предполагая, что вы используете макет данных NCHW, вы получите 10 выборок на пакет, 4 канала (каждый канал для другого класса) и разрешение 256x256, а затем первые 4 графика отображают оценки четырех классов.

Для 5-го графика ваш torch.argmax(outputs, 1).detach().cpu().squeeze() даст вам изображение 10x256x256, которое является результатом предсказания класса для всех 10 изображений в пакете, и matplotlib не может правильно построить его напрямую.Таким образом, вы захотите сделать torch.argmax(outputs[0,:,:,:], 0).detach().cpu().squeeze(), который даст вам карту 256x256, которую вы можете построить.

Поскольку результат будет в диапазоне от 0 до 3, который представляет 4 класса (и может отображаться какочень тусклое изображение), обычно люди используют палитру, чтобы раскрасить графики.Здесь приведен пример здесь и выглядит как строка cityscapes_map[p] в этом примере.

Для построения всех 10 почему бы не написать цикл for:

for i in range(outputs.size(0)):
    # do whatever you do with outputs[i, ...]
    # ...
    plt.show()

и просмотреть каждый результат в пакете один за другим.Существует также возможность иметь 10 строк в вашем сюжете, если ваш экран достаточно большой.

...