Как выбрать только конкретную цифру из набора данных MNIST, предоставленного Keras? - PullRequest
0 голосов
/ 06 июля 2018

В настоящее время я тренирую нейронную сеть прямой связи с набором данных MNIST с использованием Keras. Я загружаю набор данных в формате

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

но тогда я только хочу обучить свою модель, используя цифры 0 и 4, а не все. Как выбрать только 2 цифры? Я довольно новичок в Python и могу понять, как фильтровать набор данных mnist ...

Ответы [ 3 ]

0 голосов
/ 06 июля 2018

Y_train и Y_test дают вам метки изображений, вы можете использовать их с numpy.where, чтобы отфильтровать подмножество меток с 0 и 4.Все ваши переменные являются пустыми массивами, так что вы можете просто сделать:

import numpy as np

train_filter = np.where((Y_train == 0 ) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

, и вы можете использовать эти фильтры, чтобы получить подмножество массивов по индексу.

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]

Если вы заинтересованыв более чем 2 метках, синтаксис может сильно пострадать от того, где и или.Таким образом, вы также можете использовать numpy.isin для создания масок.

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

Вы можете использовать эти маски для логического индексирования, как и раньше.

0 голосов
/ 15 апреля 2019

при использовании Y_train = Y_train[train_mask] повышает InvalidArgumentError, когда цифры не являются последовательными и начинаются с 0 (keras ожидает последовательный диапазон меток, начинающийся с 0)

решение (для двух цифр):

train_mask = np.isin(Y_train, [2,8])
test_mask = np.isin(Y_test, [2,8])

X_train, Y_train = X_train[train_mask], np.array(Y_train[train_mask] == 8)
X_test, Y_test = X_test[test_mask], np.array(Y_test[test_mask] == 8)
0 голосов
/ 06 июля 2018

у вас есть файлы меток вместе с поездом и тестом:

train_images = mnist.train_images()
train_labels = mnist.train_labels()

test_images = mnist.test_images()
test_labels = mnist.test_labels()

вы можете использовать их вместе с простым пониманием списка, чтобы отфильтровать ваш набор данных

zero_four_test = [test_images[key] for (key, label) in enumerate(test_labels) if int(label) == 0 or int(label) == 4]
...