Как увеличить количество каналов тензора pytorch? - PullRequest
0 голосов
/ 01 августа 2020

У меня есть тензор pytorch формата [N, 2, H, W], где 2 - количество каналов. Однако для модели, которую я использую (предварительно обученный resnet18), я должен иметь размеры [N, 3, H, W]. Как увеличить количество каналов с 2 до 3?

1 Ответ

0 голосов
/ 01 августа 2020

Сохраните ваше двухканальное изображение в оттенках серого на диск, а затем выполните следующие действия:

import torch
from PIL import Image
from torchvision.models import resnet18

from torchvision import transforms

transform = transforms.Compose([            
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)])

img= Image.open('pic.jpg').convert('RGB')
tensor= transform(img)
tensor= torch.unsqueeze(tensor, 0).float().cuda()

resnet_18_model= resnet18(pretrained= True).cuda() # resnet18()
resnet_18_model.eval()
output= resnet_18_model(tensor)

output= torch.argmax(output)
print('Class Number: ', output.item())
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...