How do I perform targeted object detection with TensorFlow (Python)?
0 votes / 29 May 2020

I have complete, working object-detection code using TensorFlow (Faster R-CNN), and now I want to detect only targeted objects (one or several). For this purpose I created checkboxes to select the object, and their values are passed to the detection code. The problem is that when I select a different object while the code is running, the detection does not update with the new value; only after restarting the program does the newly selected detection appear. (A sketch of the pattern involved follows the code below.)

    # -*- coding: utf-8 -*-
    import csv
    import os
    import sys
    import time
    from datetime import datetime

    import cv2
    import numpy as np
    import qimage2ndarray
    import tensorflow as tf
    from colorama import Fore, Back, Style

    from PyQt5 import QtCore, QtGui, QtWidgets
    from PyQt5.QtCore import QBasicTimer, QDir, Qt, QUrl, pyqtSlot
    from PyQt5.QtGui import QIcon, QImage, QPixmap
    from PyQt5.QtMultimedia import QMediaContent, QMediaPlayer
    from PyQt5.QtMultimediaWidgets import QVideoWidget
    from PyQt5.QtWidgets import (QAction, QApplication, QDialog, QFileDialog,
            QHBoxLayout, QLabel, QMainWindow, QProgressBar, QPushButton,
            QSizePolicy, QSlider, QStyle, QVBoxLayout, QWidget)
    from PyQt5.uic import loadUi

    # Object detection imports
    from utils import backbone
    from utils import visualization as vis_util
    from api import object_counting_api2

    class mycode(QMainWindow):
            def __init__(self):
                    super(mycode,self).__init__()

                    loadUi("untitled5.ui",self)

                    content_widget = QtWidgets.QWidget()
                    self.scrollArea.setWidget(content_widget)
                    self._lay = QtWidgets.QVBoxLayout(content_widget)

                    #self.SHOW.clicked.connect(self.onClicked)
                    self.TEXT.setText("Kindly 'Select' a file to start counting.")
                    #self.CAPTURE.clicked.connect(self.CaptureClicked)
                    self.actionOpen.setStatusTip('Open movie')
                    self.actionOpen.triggered.connect(self.onClicked)

                    self.btn1_2.clicked.connect(self.btn1_click)

                    # Start with no target selected; onClicked() reads self.val,
                    # which would otherwise not exist before the first button click.
                    self.val = ""


            def btn1_click(self):
                    selected_val = []
                    if self.checkbox.isChecked():
                        selected_val.append(self.checkbox.text())
                    if self.checkbox1_2.isChecked():
                        selected_val.append(self.checkbox1_2.text())
                    if self.checkbox2_2.isChecked():
                        selected_val.append(self.checkbox2_2.text())

                    self.val = ""
                    for i in selected_val:            
                        if (i==selected_val[-1]):
                                self.val = self.val + i + ""
                        else:
                            self.val = self.val + i + ", "

                    self.label1_2.setText(self.val)
                    print(self.val)


            @pyqtSlot()
            def onClicked(self):
                    self.TEXT.setText('Displaying Vehicle Detection and Counting')
                    fileName, _ = QFileDialog.getOpenFileName(self, "Open Movie", QDir.homePath())
                    if not fileName:
                            # nothing selected; the user cancelled the dialog
                            return
                    detection_graph, category_index = backbone.set_model('inference_graph', 'labelmap1.pbtxt')
                    is_color_recognition_enabled = 0
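                    # NOTE: self.val is read only once, here, before the frame loop
                    # starts; checkbox changes made later never reach targeted_object.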
                    targeted_object = self.val

                    cap = cv2.VideoCapture(fileName)

                    the_result = "..."
                    with detection_graph.as_default():
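                      # TF1-style frozen-graph API; on TensorFlow 2.x the same call
                      # is available as tf.compat.v1.Session.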
                      with tf.Session(graph=detection_graph) as sess:
                        # Define input and output Tensors for detection_graph
                        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

                        # Each box represents a part of the image where a particular object was detected.
                        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

                        # Each score represents the level of confidence for each of the objects.
                        # Score is shown on the result image, together with the class label.
                        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
                        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
                        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

                        # for all the frames that are extracted from input video
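                        # NOTE: this loop runs on the GUI thread and never returns
                        # control to the Qt event loop, so clicks made while it runs
                        # are only processed after the video ends.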
                        while cap.isOpened():
                            ret, frame = cap.read()

                            if not ret:
                                print("end of the video file...")
                                break

                            input_frame = frame

                            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                            image_np_expanded = np.expand_dims(input_frame, axis=0)

                            # Actual detection.
                            (boxes, scores, classes, num) = sess.run(
                                [detection_boxes, detection_scores, detection_classes, num_detections],
                                feed_dict={image_tensor: image_np_expanded})

                            # insert information text to video frame
                            font = cv2.FONT_HERSHEY_SIMPLEX

                            # Visualization of the results of a detection.        
                            # cap.get(1) is cv2.CAP_PROP_POS_FRAMES (current frame number)
                            counter, csv_line, the_result = vis_util.visualize_boxes_and_labels_on_image_array(cap.get(1),
                                                                                                                  input_frame,
                                                                                                                  1,
                                                                                                                  is_color_recognition_enabled,
                                                                                                                  np.squeeze(boxes),
                                                                                                                  np.squeeze(classes).astype(np.int32),
                                                                                                                  np.squeeze(scores),
                                                                                                                  category_index,
                                                                                                                  targeted_objects=targeted_object,
                                                                                                                  use_normalized_coordinates=True,
                                                                                                                  line_thickness=4)
                            if len(the_result) == 0:
                                cv2.putText(input_frame, "...", (10, 35), font, 0.8, (0, 255, 255), 2, cv2.LINE_AA)
                            else:
                                cv2.putText(input_frame, the_result, (10, 35), font, 0.8, (0, 0, 0), 2, cv2.LINE_AA)


                            self.displayImage(input_frame,1)
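                            # cv2.waitKey only receives key presses when an OpenCV
                            # window has focus; with frames rendered in a Qt label,
                            # the 'q' check below never fires.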

                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                    break

                        cap.release()
                        cv2.destroyAllWindows()

            def displayImage(self, img, window=1):
                    qformat = QImage.Format_Indexed8
                    if len(img.shape) == 3:
                            if img.shape[2] == 4:
                                    qformat = QImage.Format_RGBA8888
                            else:
                                    qformat = QImage.Format_RGB888
                    # Pass the row stride explicitly: QImage otherwise assumes
                    # 32-bit-aligned rows, which OpenCV frames do not guarantee.
                    img = QImage(img, img.shape[1], img.shape[0], img.strides[0], qformat)
                    img = img.rgbSwapped()
                    self.imgLabel.setPixmap(QPixmap.fromImage(img))
                    self.imgLabel.setAlignment(QtCore.Qt.AlignHCenter | QtCore.Qt.AlignVCenter)



    if __name__ == "__main__":
            app = QApplication(sys.argv)
            window = mycode()
            window.show()
            sys.exit(app.exec_())
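For reference, the behaviour described above follows from two details in this code: `targeted_object = self.val` is read once, before the frame loop starts, and the `while cap.isOpened():` loop never returns control to the Qt event loop, so checkbox clicks are only delivered after the video ends. Below is a minimal, self-contained sketch of the pattern that addresses both points, assuming the loop stays on the GUI thread; the widget and method names (`Demo`, `run_loop`) are hypothetical stand-ins, not part of the code above, while `QApplication.processEvents()` is standard Qt. A `QThread` worker would be the more robust design for long-running detection.

    # Hypothetical minimal app illustrating the pattern; not the .ui file above.
    import sys
    import time

    from PyQt5.QtWidgets import (QApplication, QCheckBox, QLabel,
            QPushButton, QVBoxLayout, QWidget)

    class Demo(QWidget):
            def __init__(self):
                    super().__init__()
                    self.box = QCheckBox("car", self)
                    self.label = QLabel("target: (none)", self)
                    self.btn = QPushButton("start loop", self)
                    lay = QVBoxLayout(self)
                    for w in (self.box, self.label, self.btn):
                            lay.addWidget(w)
                    self.btn.clicked.connect(self.run_loop)

            def run_loop(self):
                    for _ in range(100):                  # stands in for the video loop
                            # Deliver pending clicks so the selection can change mid-loop.
                            QApplication.processEvents()
                            # Re-read the selection every iteration, not once up front.
                            target = self.box.text() if self.box.isChecked() else "(none)"
                            self.label.setText("target: " + target)
                            time.sleep(0.05)              # stands in for sess.run(...)

    if __name__ == "__main__":
            app = QApplication(sys.argv)
            demo = Demo()
            demo.show()
            sys.exit(app.exec_())

Applied to the code in the question, that would mean calling `QApplication.processEvents()` once per frame inside the `while cap.isOpened():` loop and moving `targeted_object = self.val` inside that loop, so the fresh value reaches `visualize_boxes_and_labels_on_image_array` on the next frame.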