Логический вывод C ++ с ale xnet и cv :: imread image - PullRequest
0 голосов
/ 17 января 2020

Я пытаюсь определить с помощью приложения C ++ задачу классификации изображений, используя предварительно подготовленный эль xnet net.
Я успешно вывел изображение собаки, загружающее net с помощью python:

alexnet = torchvision.models.alexnet(pretrained=True)
img = Image.open("dog.jpg")
transform = transforms.Compose([
 transforms.Resize(256),                
 transforms.CenterCrop(224),        
 transforms.ToTensor(),                  
 transforms.Normalize(                   
 mean=[0.485, 0.456, 0.406],         
 std=[0.229, 0.224, 0.225]              
 )])
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
alexnet.forward(batch_t)
_, index = torch.max(out, 1)

Результат index равен 208, Labrador_retriever, это выглядит хорошо.
Затем я сохраняю net для загрузки из приложения C ++

example = torch.rand(1, 3, 224, 224)
traced_script_module_alex = torch.jit.trace(alexnet, example)
traced_script_module.save("alexnet.pt")

Когда я загружаю на C ++ я получаю неправильный результат:

cv::Mat img = cv::imread("dog.jpg");
cv::resize(img, img, cv::Size(224, 224), cv::INTER_CUBIC);

// Convert the image and label to a tensor.
torch::Tensor img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3 }, torch::kByte);
img_tensor = img_tensor.permute({ 0, 3, 1, 2 }); // convert to CxHxW
img_tensor = img_tensor.to(torch::kFloat);
std::vector<torch::jit::IValue> input;
input.push_back(img_tensor);
torch::jit::script::Module  module = torch::jit::load("alexnet.pt");
at::Tensor output = module.forward(input).toTensor();
std::cout << output.argmax(1) << '\n';

argmax равен 463, ведро. Я думаю, что я не смотрю на то же изображение; что мне не хватает ...?

1 Ответ

1 голос
/ 17 января 2020

В вашем коде C ++ отсутствует эта часть вашего Python кода:

transform = transforms.Compose([
 transforms.Resize(256),                
 transforms.CenterCrop(224),        
 transforms.ToTensor(),                  
 transforms.Normalize(                   
 mean=[0.485, 0.456, 0.406],         
 std=[0.229, 0.224, 0.225]              
 )])
img_t = transform(img)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...