Фон
Итак, у меня есть модель на основе ансамбля, и я хочу выполнить постобработку меток, полученных от N детекторов. Для этого я хочу использовать функцию groupby, чтобы группировать обнаружения в каждом кадре с помощью функции IoU, которую я написал. Это поможет мне принимать решение для каждой группы обнаружений и избежать многократных обнаружений в одной и той же области кадра.
Вопрос: какое значение должна возвращать функция, передаваемая в groupby?
Документация Pandas groupby гласит:
by : mapping, function, label, or list of labels
Used to determine the groups for the groupby.
If ``by`` is a function, it's called on each value of the object's
index. If a dict or Series is passed, the Series or dict VALUES
will be used to determine the groups (the Series' values are first
aligned; see ``.align()`` method). If an ndarray is passed, the
values are used as-is determine the groups. A label or list of
labels may be passed to group by the columns in ``self``. Notice
that a tuple is interpreted a (single) key.
Пример моих данных:
xmin ymin xmax ... confidence class frame_idx
25092 0.740968 0.379832 0.837317 ... 0.931677 2.0 988.0
25093 0.342227 0.272692 0.421934 ... 0.904239 2.0 988.0
25094 0.624577 0.333927 0.739960 ... 0.854978 2.0 988.0
25095 0.229680 0.260676 0.272042 ... 0.708415 2.0 988.0
25096 0.420753 0.291226 0.514133 ... 0.610117 2.0 988.0
25097 0.061343 0.265538 0.110653 ... 0.369827 2.0 988.0
25098 0.486258 0.711575 0.558676 ... 0.173804 2.0 988.0
25099 0.069983 0.596780 0.571130 ... 0.145625 2.0 988.0
25100 0.756423 0.693024 0.818869 ... 0.126882 1.0 988.0
25101 0.448371 0.000000 0.628947 ... 0.124258 1.0 988.0
25102 0.664900 0.808117 0.728619 ... 0.119267 2.0 988.0
25103 0.000000 0.776122 0.407657 ... 0.119260 2.0 988.0
25104 0.796306 0.878808 0.844132 ... 0.115530 1.0 988.0
25105 0.006830 0.000000 0.045303 ... 0.110469 1.0 988.0
Правка (EDIT)
Я добавил свою функцию группировки обнаружений; сейчас она возвращает обнаружения (dets) и список групп, где каждая группа — это список индексов обнаружений.
def group_detections(dets, thresh=0.5):
    """Greedily cluster overlapping detection boxes.

    The detections are sorted by descending score (column 4). The best
    remaining box then seeds a group and absorbs every other remaining box
    whose overlap ratio with the seed exceeds ``thresh``. This repeats until
    no boxes are left.

    NOTE: all detections are assumed to belong to the same class; the class
    is never inspected here.

    The overlap ratio is ``intersection_area / candidate_box_area``. It is
    normalised by the candidate box only, so it is not a symmetric IoU.
    Candidates with zero area never join another box's group; they end up
    as their own single-element group.

    Args:
        dets: 2-D array with one detection per row. Columns 0-3 are read as
            ``x1, y1, x2, y2`` and column 4 as the score used for ordering.
            Any further columns are carried along untouched.
        thresh: Overlap ratio a candidate must exceed to join the seed's
            group (lower values merge more boxes).

    Returns:
        A tuple ``(dets, groups)``. ``dets`` is the input sorted by
        descending score (or the input itself when it is empty). ``groups``
        is a list of lists of plain ``int`` row indices into the *sorted*
        ``dets``; the first index of each group is its seed.
    """
    groups = []
    # Nothing to group: hand the empty input back unchanged.
    if dets.size == 0:
        return dets, groups

    # Sort by descending score so each group is seeded by its strongest box.
    dets = dets[(-dets[:, 4]).argsort(), :]

    # Box corners and per-box area.
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    area = (x2 - x1) * (y2 - y1)

    # Reversed so the best remaining detection is always the last element.
    all_idxs = np.arange(dets.shape[0])[::-1]
    while len(all_idxs) > 0:
        seed = int(all_idxs[-1])
        rest = all_idxs[:-1]

        # Intersection rectangle between the seed and every other candidate.
        xx1 = np.maximum(x1[seed], x1[rest])
        yy1 = np.maximum(y1[seed], y1[rest])
        xx2 = np.minimum(x2[seed], x2[rest])
        yy2 = np.minimum(y2[seed], y2[rest])
        w = np.maximum(0, xx2 - xx1)
        h = np.maximum(0, yy2 - yy1)

        # Divide by a dummy 1 for zero-area candidates and mask them out
        # afterwards. This avoids the 0/0 -> NaN (and the accompanying
        # warning or FloatingPointError) while keeping them unmatched.
        rest_area = area[rest]
        has_area = rest_area != 0
        overlap = (w * h) / np.where(has_area, rest_area, 1)
        matched = np.flatnonzero(has_area & (overlap > thresh))

        # Members keep the order of ``rest`` (descending row index).
        groups.append([seed] + rest[matched].tolist())

        # Drop the seed (last position) together with everything it absorbed.
        all_idxs = np.delete(all_idxs, np.append(matched, len(all_idxs) - 1))
    return dets, groups
Пример вывода: dets:
array([[0.70187318, 0.37271658, 0.72741735, 0.4855186 , 0.37807396,
1. , 0. ],
[0.15145046, 0.28159913, 0.1774891 , 0.33394483, 0.30386543,
1. , 0. ],
[0.51533937, 0.27924138, 0.5292989 , 0.33969468, 0.19332546,
1. , 0. ],
[0.78109378, 0.87562817, 0.82961792, 1. , 0.13274428,
1. , 0. ],
[0.14983323, 0.28213197, 0.16887194, 0.33710539, 0.12764682,
1. , 0. ],
[0.14637044, 0.033903 , 0.26365086, 0.57686102, 0.12654687,
1. , 0. ],
[0.22914155, 0.26150814, 0.27148777, 0.31108513, 0.10923221,
1. , 0. ],
[0.18359725, 0.02795739, 0.23144765, 0.4838326 , 0.10619797,
1. , 0. ],
[0.22894984, 0.28633049, 0.24446395, 0.35711476, 0.10326774,
1. , 0. ],
[0.15507454, 0.2818917 , 0.18600157, 0.34016567, 0.10280589,
1. , 0. ],
[0.0071237 , 0. , 0.04561174, 0.08170507, 0.10274705,
1. , 0. ],
[0.74405503, 0.38016886, 0.8356874 , 0.4577626 , 0.10082066,
1. , 0. ]])
groups:
[[0], [1, 9, 4], [2], [3], [5, 8, 7, 6], [10], [11]]