Преобразуйте изображения в белое на черном и прогнозируйте - PullRequest
0 голосов
/ 22 октября 2019

Я хочу предсказать новые настроенные изображения с обученной моделью LeNet от здесь . Настраиваемые изображения черно-белые, поэтому мне нужно преобразовать их в черно-белые. enter image description here

# Load & transform image
ori_img = Image.open('./test/2.png').convert('L')
img = np.invert(ori_img) #Transform images to white on black
t = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
img = torch.autograd.Variable(t(img).unsqueeze(0))
ori_img.close()

# Predict
model.eval()
output = model(img)
pred = output.data.max(1, keepdim=True)[1][0][0]
print('Prediction: {}'.format(pred))

Результат, который я получил:

TypeError                                 Traceback (most recent call last)
<ipython-input-182-abbffa2ce0d8> in <module>
      7     transforms.Normalize((0.1307,), (0.3081,))
      8 ])
----> 9 img = torch.autograd.Variable(t(img).unsqueeze(0))
     10 ori_img.close()

~/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
     59     def __call__(self, img):
     60         for t in self.transforms:
---> 61             img = t(img)
     62         return img
     63 

~/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
    196             PIL Image: Rescaled image.
    197         """
--> 198         return F.resize(img, self.size, self.interpolation)
    199 
    200     def __repr__(self):

~/.local/lib/python3.6/site-packages/torchvision/transforms/functional.py in resize(img, size, interpolation)
    236     """
    237     if not _is_pil_image(img):
--> 238         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    239     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
    240         raise TypeError('Got inappropriate size arg: {}'.format(size))

TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

Когда я комментирую img = np.invert(ori_img) Я не получаю ошибок, но все результаты прогноза 2 с.

Кто-нибудь может помочь? Большое спасибо.

1 Ответ

1 голос
/ 22 октября 2019

Вы можете использовать эту функцию: PIL.Image.fromarray для создания изображения PIL из массива numpy, а затем вы можете использовать функцию PIL.ImageOps.invert для инвертирования цветов. ,Тогда ваша img переменная должна быть правильного типа и инвертирована.

...