Я обучил сеть Mask R-CNN для сегментации экземпляров (instance segmentation) яблок. Я могу загружать веса и получать предсказания для своих тестовых изображений. Сгенерированные маски, похоже, находятся в правильных местах, но сама маска не имеет реальной формы — она выглядит просто как группа разрозненных пикселей.
Обучение выполнялось на наборе данных из этой статьи, а вот ссылка на GitHub с кодом, который использовался для обучения и генерации весов.
Код для предсказания выглядит следующим образом (я опустил части, где создаю переменные с путями и задаю сами пути):
import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput
import torch
import torch.utils.data
import torchvision
from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import utility.utils as utils
import utility.transforms as T
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline
def get_transform(train):
    """Build the transform pipeline passed to ``AppleDataset``.

    The pipeline always starts with ``T.ToTensor()``. When ``train`` is
    truthy, a ``T.RandomHorizontalFlip(0.5)`` augmentation is appended
    after it.
    """
    augmentations = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose([T.ToTensor()] + augmentations)
def get_maskrcnn_model_instance(num_classes):
    """Create a ResNet-50-FPN Mask R-CNN whose heads predict ``num_classes`` classes.

    The model is built with ``pretrained=False``, so no COCO detection weights
    are loaded here. Trained weights are expected to be loaded afterwards, e.g.
    from a checkpoint via ``load_state_dict``. Depending on the torchvision
    version, the backbone may still default to ImageNet weights (see the
    ``pretrained_backbone`` argument).

    Both the box predictor and the mask predictor are replaced with fresh
    heads sized for ``num_classes``. Each reuses the input width of the head
    it replaces.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    heads = model.roi_heads

    # Box head: keep the feature width the original classifier consumed.
    box_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_features, num_classes)

    # Mask head: keep its input channel count; use 256 hidden channels.
    mask_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_channels, 256, num_classes)

    return model
# Run a single test image through the trained Mask R-CNN (notebook-style cells).
num_classes = 2
device = torch.device('cpu')

model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
# strict=False tolerates name mismatches between the checkpoint and the model,
# but silently: any parameter not found under the expected key keeps its initial
# (e.g. random) value. A head left uninitialised this way could yield noise-like
# masks while boxes still look plausible, so report mismatches explicitly.
incompatible = model.load_state_dict(checkpoint['model'], strict=False)
if incompatible.missing_keys:
    print('Missing keys (left at initial values):', incompatible.missing_keys)
if incompatible.unexpected_keys:
    print('Unexpected keys in checkpoint (ignored):', incompatible.unexpected_keys)
# Keep the model on the same device as the input tensor below.
model.to(device)

dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]

model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])
prediction
# Convert the float CHW tensor back to an HWC uint8 image for display.
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
(I am unable to upload the image here since it is over 2 MB.)
# prediction[0]['masks'] holds one soft mask per detection, shaped [N, 1, H, W],
# with per-pixel probabilities in [0, 1]. Threshold at 0.5 to get the binary
# instance mask instead of displaying the raw probability map.
# To show every detection in one image, threshold the masks of the detections
# you keep (e.g. keep = prediction[0]['scores'] >= 0.5) and OR them together:
# (prediction[0]['masks'][keep, 0] >= 0.5).any(dim=0).
first_mask = prediction[0]['masks'][0, 0] >= 0.5
Image.fromarray(first_mask.to(torch.uint8).mul(255).cpu().numpy())
Вот ссылка на исходное изображение на Imgur; ниже приведена предсказанная маска для одного из экземпляров:
![Mask output for one instance](https://i.stack.imgur.com/MptQ8.png)
Кроме того, не могли бы вы помочь мне разобраться в структуре данных сгенерированного предсказания, показанного ниже? Как получить доступ к маскам и как создать одно изображение, на котором отображаются все маски?
[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
[1418.7872, 1467.0619, 1732.0828, 1796.1527],
[1608.0396, 2064.6482, 1710.7534, 2206.5535],
[2326.3750, 1690.3418, 2542.2112, 1883.2626],
[2213.2024, 1864.3657, 2299.8933, 1963.0178],
[1112.9083, 1732.5953, 1236.7600, 1823.0170],
[1150.8256, 614.0334, 1218.8584, 711.4094],
[ 942.7086, 794.6043, 1138.2318, 1008.0430],
[1065.4371, 723.0493, 1192.7570, 870.3763],
[1002.3103, 883.4616, 1146.9994, 1006.6841],
[1315.2816, 1680.8625, 1531.3210, 1989.3317],
[1244.5769, 1925.0903, 1459.5417, 2175.3252],
[1725.2191, 2082.6187, 1934.0227, 2274.2952],
[ 936.3065, 1554.3765, 1014.2722, 1659.4229],
[ 934.8851, 1541.3331, 1090.4736, 1657.3751],
[2486.0120, 776.4577, 2547.2329, 847.9725],
[2336.1675, 698.6327, 2508.6492, 921.4550],
[2368.4077, 1954.1102, 2448.4004, 2049.5796],
[1899.1403, 1775.2371, 2035.7561, 1962.6923],
[2176.0664, 1075.1553, 2398.6084, 1267.2555],
[2274.8899, 641.6769, 2395.9634, 791.3353],
[2535.1580, 874.4780, 2642.8213, 966.4614],
[2183.4236, 619.9688, 2288.5676, 758.6825],
[2183.9832, 1122.9382, 2334.9583, 1263.3226],
[1135.7822, 779.0529, 1225.9871, 890.0135],
[ 317.3954, 1328.6995, 397.3900, 1467.7740],
[ 945.4811, 1833.3708, 997.2318, 1878.8607],
[1992.4447, 679.4969, 2134.6667, 835.8701],
[1098.5416, 1452.7799, 1429.1808, 1771.4460],
[1657.3193, 1405.5405, 1781.6273, 1574.6780],
[1443.8911, 1747.1544, 1739.0361, 2076.9724],
[1092.6003, 1165.3340, 1206.0881, 1383.8314],
[2466.4170, 1945.5931, 2555.1931, 2039.8368],
[2561.8508, 1616.2659, 2672.1033, 1742.2332],
[1894.4806, 907.9214, 2097.1875, 1182.6473],
[2321.5005, 1701.3344, 2368.3699, 1865.3914],
[2180.0781, 567.5969, 2344.6357, 763.4360],
[1845.7612, 668.6808, 2045.2688, 899.8501],
[1858.9216, 2145.7097, 1961.8870, 2273.5088],
[ 261.4607, 1314.0154, 396.9288, 1486.9498],
[2488.1682, 1585.2357, 2669.0178, 1794.9926],
[2696.9548, 936.0087, 2802.7961, 1025.2294],
[1593.6837, 1489.8641, 1720.3124, 1627.8135],
[2517.9468, 857.1713, 2567.1125, 929.4335],
[1943.2167, 636.3422, 2151.4419, 853.8924],
[2143.5664, 1100.0521, 2308.1570, 1290.7125],
[2140.9231, 1947.9692, 2238.6956, 2000.6249],
[1461.6316, 2105.2593, 1559.7675, 2189.0264],
[2114.0781, 374.8153, 2222.8838, 559.9851],
[2350.5320, 726.5779, 2466.8140, 878.2617]]),
'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1]),
'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
'masks': tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]])}]