Как решить эту ошибку, TypeError: take (): аргумент 'index' (позиция 1) должен быть Tensor, а не numpy.ndarray? - PullRequest
0 голосов
/ 09 июля 2019

Ну, я пытаюсь создать классификатор изображений, используя SVM.До этого я использовал CNN, и именно поэтому использовался PyTorch.Но теперь, когда я хочу использовать SVM, я должен использовать Sci-Kit Learn.Итак, для начала мне нужно разделить мой набор данных для обучения части и проверки части.Мы будем использовать train_test_split для этого.Я занимаюсь программированием в Google Colab и храню свои изображения в некоторых папках на Google Диске.Это моя часть загрузчика данных -

# choose the training and test datasets
train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
test_data = datasets.ImageFolder(data+"/val", transform = transform_test)
#n_classes = test_data.shape[1]
n_classes = len(test_data.classes)
print(n_classes)

batch_size = 32

dataloader_train = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True, num_workers=2)
dataloader_test = torch.utils.data.DataLoader(test_data, batch_size, num_workers=2)

Далее я показываю ее и делю набор данных на это -

images, labels = next(iter(dataloader_train))
imshow_numpy(images[0].numpy())
print(images.shape)
X_train, X_test, y_train, y_test = train_test_split(images, labels)

Но эта строка ----> 5 X_train, X_test, y_train, y_test = train_test_split(images, labels) выдает мне ошибку.Я не знаю, как это решить.Кто-нибудь знает об этом?

...