AttributeError при выборе нулей и единиц из набора данных MNIST с помощью PyTorch - PullRequest
0 голосов
/ 05 августа 2020

Я следую руководству по построению модели для классификации только между 0 и 1 из набора данных MNIST. Однако то, что они предлагают сделать, чтобы выбрать значения 0/1, не работает для меня и вызывает ошибку. Это код, который я использую:

from torch.utils.data import DataLoader

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())

train_idx = mnist_train.train_labels <= 1
mnist_train.train_data = mnist_train.train_data[train_idx]
mnist_train.train_labels = mnist_train.train_labels[train_idx]

test_idx = mnist_test.test_labels <= 1
mnist_test.test_data = mnist_test.test_data[test_idx]
mnist_test.test_labels = mnist_test.test_labels[test_idx]

train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

И вот результат, который я получаю, когда запускаю его:

  File "<ipython-input-2-aa7f63047cd9>", line 8, in <module>
    mnist_train.train_data = mnist_train.train_data[train_idx]

AttributeError: can't set attribute

Я также пытался изменить:

mnist_train.train_data = mnist_train.train_data[train_idx]

автор:

try:
    mnist_train.train_data = mnist_train.train_data[train_idx]
except AttributeError:
    mnist_train._train_data = mnist_train.train_data[train_idx]

И добавление этого дополнительного символа «_» в каждом утверждении такого рода, казалось, решало проблему, но позже, когда я пытался обучить модель, я понял, что она не выбрана только нули и единицы. Есть предложения?

1 Ответ

0 голосов
/ 05 августа 2020

Я нашел решение благодаря @ptrblck на форуме PyTorch (https://discuss.pytorch.org/t/how-to-use-one-class-of-number-in-mnist/26276/3):

С последними обновлениями PyTorch новый синтаксис:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())

train_idx = mnist_train.train_labels <= 1
mnist_train.data = mnist_train.train_data[train_idx]
mnist_train.targets = mnist_train.train_labels[train_idx]

test_idx = mnist_test.test_labels <= 1
mnist_test.data = mnist_test.test_data[test_idx]
mnist_test.targets = mnist_test.test_labels[test_idx]

train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...