построение образцов MNIST - PullRequest
       18

построение образцов MNIST

0 голосов
/ 18 декабря 2018

Я пытаюсь построить 10 выборок из набора данных MNIST.Один из каждой цифры.Вот код:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    plt.imshow(plottable_image, cmap='gray_r')
    plt.subplot(2, 5, i + 1)

plt.plot()

По какой-то причине нулевая цифра пропускается на графике.

Почему?

Ответы [ 2 ]

0 голосов
/ 18 декабря 2018

Хорошо, я понял.Проблема заключалась в том, что вы определяли участок после печати imshow.Итак, ваш первый сюжет был перезаписан вторым.Чтобы ваш код работал, просто поменяйте местами порядок ваших команд следующим образом.Кроме того, я не понимаю, почему вы используете plt.plot() в конце.

plt.subplot(2, 5, i + 1) # <-- You have put this command after imshow 
plt.imshow(plottable_image, cmap='gray_r')

Вот еще одна альтернатива для ваших знаний:

fig = plt.figure()

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax = fig.add_subplot(2, 5, i+1)
    ax.imshow(plottable_image, cmap='gray_r')

Вы также можете еще больше сократить Скоттакод (размещен ниже) с использованием следующего:

fig, ax = plt.subplots(2,5)
for i, ax in enumerate(ax.flatten()):
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax.imshow(plottable_image, cmap='gray_r')

enter image description here

0 голосов
/ 18 декабря 2018

Попробуйте:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

fig, ax = plt.subplots(2,5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')

Вывод:

enter image description here

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...