Алгоритм отслеживания объектов обнаруживает объекты более одного раза - PullRequest
0 голосов
/ 07 мая 2020

Я разработал систему подсчета объектов, используя EfficientDet и эту ссылку. Она нормально работает с каждым кадром. Но мне нужен высокий FPS. Итак, я запускаю модель обнаружения объектов не в каждом кадре. Я запускаю его в каждом 5-м кадре. И реализую отслеживание объекта между 5 кадрами. В результате я вижу, что объект помечен 5 раз. Если я запускаю модель не в каждом 5-м кадре, а в каждом 10-м кадре, я вижу, что объект помечается 10 раз. Я имею в виду, мне нужен один идентификатор на кадр. Но я вижу, что 5 накапливаются за 5 кадров, сбрасываются, а затем повторяются для следующего пакета. Также меняются координаты ограничивающей рамки. Я знаю, это звучит странно. Итак, вот мой результат.

А это функция.

def saved_model_video(self, video_path: Text, output_video: Text, **kwargs):
    """Perform video inference for the given saved model."""
    import cv2  # pylint: disable=g-import-not-at-top

    driver = inference.ServingDriver(
        self.model_name,
        self.ckpt_path,
        batch_size=1,
        use_xla=self.use_xla,
        model_params=self.model_config.as_dict())
    driver.load(self.saved_model_dir)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
      print('Error opening input video: {}'.format(video_path))

    W = None
    H = None
    # instantiate our centroid tracker, then initialize a list to store
    # each of our dlib correlation trackers, followed by a dictionary to
    # map each unique object ID to a TrackableObject
    ct = CentroidTracker(maxDisappeared=3, maxDistance=50)
    trackers = []
    trackableObjects = {}
    totalFrames = 0
    totalDown = 0
    totalUp = 0
    fps = FPS().start()
    label_id_mapping=None
    disable_pyfun=True

    while cap.isOpened():
      # Capture frame-by-frame
      ret, frame = cap.read()
      if not ret:
        break
      rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
      (H, W) =frame.shape[:2]  
      #status = "Waiting"
      rects = []
      if totalFrames % 3 == 0:
        status = "Detecting"
        trackers = []
        raw_frames = [np.array(frame)]
        detections_bs = driver.serve_images(raw_frames)
        for i in np.arange(0, detections_bs.shape[1]):

            prediction = detections_bs[0][i]               
            boxes = prediction[1:5]
            classes = prediction[6].astype(int)
            scores = prediction[5]
            if not disable_pyfun:

                boxes[:, [0, 1, 2, 3]] = boxes[:, [1, 0, 3, 2]]
            label_id_mapping = label_id_mapping or inference.coco_id_mapping
            (y, x, height, width) = boxes.astype("int")
            (startX, startY, endX, endY) = (x, y, x+width, y+height)
            tracker = dlib.correlation_tracker()
            rect = dlib.rectangle(startX, startY, endX, endY)
            tracker.start_track(raw_frames[0], rect)
            # add the tracker to our list of trackers so we can
            # utilize it during skip frames
            trackers.append(tracker)
      else:
        for tracker in trackers:
            # set the status of our system to be 'tracking' rather
            # than 'waiting' or 'detecting'
            #status = "Tracking"
            # update the tracker and grab the updated position
            tracker.update(raw_frames[0])
            pos = tracker.get_position()
            # unpack the position object
            startX = int(pos.left())
            startY = int(pos.top())
            endX = int(pos.right())
            endY = int(pos.bottom())
            # add the bounding box coordinates to the rectangles list
            rects.append((startX, startY, endX, endY))

      cv2.line(raw_frames[0], (0, H // 2), (W, H // 2), (0, 255, 255), 2)
      objects = ct.update(rects)

    # loop over the tracked objects
      for (objectID, centroid) in objects.items():
        to = trackableObjects.get(objectID, None)

        # if there is no existing trackable object, create one
        if to is None:
            to = TrackableObject(objectID, centroid)

        # otherwise, there is a trackable object so we can utilize it
        # to determine direction
        else:
            # the difference between the y-coordinate of the *current*
            # centroid and the mean of *previous* centroids will tell
            # us in which direction the object is moving (negative for
            # 'up' and positive for 'down')
            y = [c[1] for c in to.centroids]
            direction = centroid[1] - np.mean(y)
            to.centroids.append(centroid)
            if not to.counted:
                # if the direction is negative (indicating the object
                # is moving up) AND the centroid is above the center
                # line, count the object
                if direction < 0 and centroid[1] < H // 2:
                    totalUp += 1
                    to.counted = True

                # if the direction is positive (indicating the object
                # is moving down) AND the centroid is below the
                # center line, count the object
                elif direction > 0 and centroid[1] > H // 2:
                    totalDown += 1
                    to.counted = True

        trackableObjects[objectID] = to
        text = "ID {}".format(objectID)
        cv2.putText(raw_frames[0], text, (centroid[0] - 10, centroid[1] - 10),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        cv2.circle(raw_frames[0], (centroid[0], centroid[1]), 4, (0, 255, 0), -1)


      info = [
        ("Up", totalUp),
        ("Down", totalDown),
        #("Status", status),
        ]
      for (i, (k, v)) in enumerate(info):
        text = "{}: {}".format(k, v)
        cv2.putText(raw_frames[0], text, (10, H - ((i * 40) + 20)),
            cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2) 
      new_frame = driver.visualize(raw_frames[0], detections_bs[0], **kwargs)
      cv2.imshow('Frame',new_frame)
        # Press Q on keyboard to  exit
      if cv2.waitKey(1) & 0xFF == ord('q'):
        break
      totalFrames += 1
      fps.update()

# stop the timer and display FPS information
    fps.stop()
    print(totalUp)
    print(totalDown)
    print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
    print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...