Проблема получения ярлыков обучающего набора - PullRequest
2 голосов
/ 25 октября 2019

Я использовал функцию train_test_split, чтобы разделить мои данные на X_train, X_test, y_train, y_test, а затем использовал utils.data.DataLoader, чтобы передать их на мой CNN, но проблема в том, что я делаюне знаю, как получить доступ к тензору моих меток для создания матрицы путаницы и сравнения их с моим тензором прогноза. Я знаю, это основной вопрос, но в любом случае ваша помощь приветствуется.

X_train, X_test, y_train, y_test = train_test_split(faces, emotions, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=41)

, и я использовал

train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)

для подачи данных в мою сеть. Кажется, вы можете получить доступ к своим ярлыкам, простовведите атрибут target после вашего train_set, например train_set.targets, но у меня это не работает. Как я могу получить свои этикетки?

1 Ответ

0 голосов
/ 25 октября 2019

Объект DataLoader PyTorch примерно используется следующим образом:

for i, (inputs, labels) in enumerate(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

В общем, я бы предложил использовать два DataLoader, один для обучения и один для тестирования / валидации. Поскольку вы хотите создать запутанную матрицу, вы можете получить доступ к своим меткам просто с помощью массива numpy y_train и своего прогноза preds, например, путем объединения их внутри цикла с массивом numpy.

Для получения дополнительной информации оКак использовать DataLoader, я предлагаю посмотреть этот очень хороший учебник: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

и

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

...