Среднее изображение MNIST - PullRequest
0 голосов
/ 03 марта 2019

Работая с набором данных MNIST, я пытаюсь найти среднее изображение для каждой отдельной цифры (0-9).Следующий код дает мне каждое отдельное изображение из набора данных, но я не уверен, как получить среднее значение для каждого класса (0-9)

data = io.loadmat('mnist-original.mat')

x, y = data['data'].T, data['label'].T

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)


a=np.unique(y, return_index=True)
b = a[1]

plt.figure(figsize=(15,4.5))
for i in b:
    img=x[i][:].reshape(28,28)
    plt.imshow(img)
    plt.show()  

1 Ответ

0 голосов
/ 03 марта 2019

Пакет numpy_indexed (отказ от ответственности: я его автор) предоставляет этот тип функциональности в векторизованном виде:

import numpy_indexed as npi
digits, means = npi.group_by(y).mean(x)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...