Построить все прогнозы модели мульти-метки классификации - PullRequest
0 голосов
/ 05 апреля 2020

Я хотел бы построить график ввода и вывода модели, которую я пытаюсь обучить:

форма входных данных:

processed_data.shape
(100, 64, 256, 2)

Как это выглядит:

processed_data
array([[[[ 1.93965047e+04,  8.49532852e-01],
         [ 1.93965047e+04,  8.49463479e-01],

Форма выходных данных:

output.shape
(100, 6)

Вывод - это, в основном, вероятности каждой метки

output = model.predict(processed_data)

output
array([[0.53827614, 0.64929205, 0.48180097, 0.50065327, 0.43016508,
        0.50453395]

Я бы хотел как-то построить для каждого экземпляра в обработанных данных предсказанные вероятности классов (поскольку это проблема классификации по нескольким меткам), но я изо всех сил пытаюсь это сделать. Итак, как я могу построить обработанные данные, но не уверен, как изобразить вероятности для каждого случая ввода. Я хотел бы иметь возможность пометить все 6 возможных классов на каждом из выходных. Я немного растерялся ... Есть предложения?

Пока я только строю график ввода: shape = output.shape [0]

for i in range(it):
    fig,axs = plt.subplots(5,2,figsize=(10,10))

    if isinstance(data,list): 
        inp = data[i]
        outp = output[i]
    else: 
        inp = data
        outp = output

    for j in range(5):
        r = randint(0,shape)
        axs[j,0].imshow(inp[r,...,0]); 
        axs[j,0].title.set_text('Input {}'.format(r))

1 Ответ

1 голос
/ 05 апреля 2020

Теперь я отредактировал свой ответ, чтобы лучше понять вопрос. Этот код будет отображать изображения вместе с выводом.

import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt

img_paths = ['../python/imgs/Image001.png',
             '../python/imgs/Image002.png',
             '../python/imgs/Image003.png',
             '../python/imgs/Image004.png',
             '../python/imgs/Image005.png']

input  = np.array([mpimg.imread(path) for path in img_paths])
output = np.random.rand(5, 6)

print(input.shape, output.shape)

fig, axs = plt.subplots(2, 5, figsize=(8, 4), sharey = 'row')

for i, sample in enumerate(range(5)):
    o = output[sample]

    axs[0,i].set_title(f'Sample {sample + 1}')
    axs[0,i].imshow(input[i,:])
    axs[0,i].axis('off')

    axs[1,i].bar(range(6), o)
    axs[1,i].set_xticks(range(6))
    axs[1,i].set_xticklabels([f'{i+1}' for i in range(6)])

plt.show()

Вывод:

(5, 1510, 2560, 4) (5, 6)

output plot

Важной частью является plt.subplots вызов, где вы можете создать сетку графиков так, как вам нравится (если вы хотите на самом деле построить все 100 изображений, вы, вероятно, предпочтете вертикальное расположение).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...