Я хочу предсказать новые настроенные изображения с обученной моделью LeNet от здесь . Настраиваемые изображения черно-белые, поэтому мне нужно преобразовать их в черно-белые.
# Load & transform image
ori_img = Image.open('./test/2.png').convert('L')
img = np.invert(ori_img) #Transform images to white on black
t = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
img = torch.autograd.Variable(t(img).unsqueeze(0))
ori_img.close()
# Predict
model.eval()
output = model(img)
pred = output.data.max(1, keepdim=True)[1][0][0]
print('Prediction: {}'.format(pred))
Результат, который я получил:
TypeError Traceback (most recent call last)
<ipython-input-182-abbffa2ce0d8> in <module>
7 transforms.Normalize((0.1307,), (0.3081,))
8 ])
----> 9 img = torch.autograd.Variable(t(img).unsqueeze(0))
10 ori_img.close()
~/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
59 def __call__(self, img):
60 for t in self.transforms:
---> 61 img = t(img)
62 return img
63
~/.local/lib/python3.6/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
196 PIL Image: Rescaled image.
197 """
--> 198 return F.resize(img, self.size, self.interpolation)
199
200 def __repr__(self):
~/.local/lib/python3.6/site-packages/torchvision/transforms/functional.py in resize(img, size, interpolation)
236 """
237 if not _is_pil_image(img):
--> 238 raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
239 if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
240 raise TypeError('Got inappropriate size arg: {}'.format(size))
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>
Когда я комментирую img = np.invert(ori_img)
Я не получаю ошибок, но все результаты прогноза 2
с.
Кто-нибудь может помочь? Большое спасибо.