How to run TFLite model inference on a video input
1 vote
/ 28 May 2020

I am trying to test my exported MobileNet v2 SSDLite model (https://drive.google.com/open?id=1htyBE6R62yVCV8v-9muEJ_lGmoPxQMmJ) on a video. I found an answer here and modified it in a few places to adapt it to my model:

import cv2
from PIL import Image
import numpy as np
import tensorflow as tf

def read_tensor_from_frame(frame, input_height=300, input_width=300,
        input_mean=128, input_std=128):
  output_name = "normalized"
  # float_caster = tf.cast(frame, tf.float32)
  float_caster = tf.cast(frame, tf.uint8)
  dims_expander = tf.expand_dims(float_caster, 0)
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
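  # NOTE: a new TF graph and Session are built for every frame here, which is
  # very expensive and is likely the main reason the video appears frozen.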
  sess = tf.Session()
  result = sess.run(normalized)
  return result

def load_labels(label_file):
  label = []
  proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
  for l in proto_as_ascii_lines:
    label.append(l.rstrip())
  return label

def VideoSrcInit(path):
    cap = cv2.VideoCapture(path)
    flag, image = cap.read()
    if flag:
        print("Valid video path. Let's move to detection!")
    else:
        raise ValueError("Video initialization failed. Please make sure the video path is valid.")
    return cap

def main():
  Labels_Path = "C:/MachineLearning/CV/coco-labelmap.txt"
  Model_Path = "C:/MachineLearning/CV/previous_float_model_converted_from_ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tflite"
  input_path = "C:/MachineLearning/CV/Object_Tracking/video2.mp4"

  ##Loading labels
  labels = load_labels(Labels_Path)

  ##Load tflite model and allocate tensors
  interpreter = tf.lite.Interpreter(model_path=Model_Path)
  interpreter.allocate_tensors()
  # Get input and output tensors.
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  input_shape = input_details[0]['shape']

  ##Read video
  cap = VideoSrcInit(input_path)

  while True:
    ok, cv_image = cap.read()
    if not ok:
      break

    ## Convert the frame to RGB, since OpenCV reads frames in BGR order
    image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

    ## Convert the image into a tensor
    image_tensor = read_tensor_from_frame(image, 300, 300)

    ##Test model
    interpreter.set_tensor(input_details[0]['index'], image_tensor)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    ## You need to inspect the output_data variable and map it onto the
    ## frame in order to draw the bounding boxes.
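    ## For a postprocessed SSD TFLite model, the four output tensors are
    ## typically: boxes [1, N, 4] (ymin, xmin, ymax, xmax, normalized),
    ## class indices, scores, and the number of detections.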


    cv2.namedWindow("cv_image", cv2.WINDOW_NORMAL)
    cv2.imshow("cv_image",cv_image)

    ## Use p to pause the video and q to terminate the program
    key = cv2.waitKey(10) & 0xFF
    if key == ord("q"):
      break
    elif key == ord("p"):
      cv2.waitKey(0)
      continue
  cap.release()

if __name__ == '__main__':
  main()

When I run this script with my TFLite model, the FPS is very, very low, the video is almost frozen. What is wrong with the script?

1 Answer

1 vote
/ 31 May 2020

I solved it myself; here is the script:

import numpy as np
import tensorflow as tf
import cv2
import time
print(tf.__version__)

Model_Path = "C:/MachineLearning/CV/uint8_dequantized_model_converted_from_exported_model.tflite"
Video_path = "C:/MachineLearning/CV/Object_Tracking/video2.mp4"

interpreter = tf.lite.Interpreter(model_path=Model_Path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# COCO class names (80 classes); the model's integer class output is assumed
# to index directly into this list
class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']

cap = cv2.VideoCapture(Video_path)
ok, frame_image = cap.read()
original_image_height, original_image_width, _ = frame_image.shape
thickness = original_image_height // 500   # scale line thickness to the frame size
fontsize = original_image_height / 1500    # scale font size to the frame size
print(thickness)
print(fontsize)

while True:
    ok, frame_image = cap.read()
    if not ok:
        break

    model_interpreter_start_time = time.time()
    resize_img = cv2.resize(frame_image, (300, 300), interpolation=cv2.INTER_CUBIC)
    reshape_image = resize_img.reshape(300, 300, 3)
    image_np_expanded = np.expand_dims(reshape_image, axis=0)
    image_np_expanded = image_np_expanded.astype('uint8')  # use 'float32' for a float model

    interpreter.set_tensor(input_details[0]['index'], image_np_expanded) 
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])    # boxes: [1, N, 4] (ymin, xmin, ymax, xmax), normalized
    output_data_1 = interpreter.get_tensor(output_details[1]['index'])  # class indices: [1, N]
    output_data_2 = interpreter.get_tensor(output_details[2]['index'])  # scores: [1, N]
    output_data_3 = interpreter.get_tensor(output_details[3]['index'])  # number of detections
    each_interpreter_time = time.time() - model_interpreter_start_time

    for i in range(len(output_data_1[0])):
        score = output_data_2[0][i]
        if score > 0.3:  # confidence threshold
            label = "{}: {:.2f}% ".format(class_names[int(output_data_1[0][i])], score * 100)
            label2 = "inference time : {:.3f}s".format(each_interpreter_time)
            left_up_corner = (int(output_data[0][i][1]*original_image_width), int(output_data[0][i][0]*original_image_height))
            left_up_corner_higher = (int(output_data[0][i][1]*original_image_width), int(output_data[0][i][0]*original_image_height)-20)
            right_down_corner = (int(output_data[0][i][3]*original_image_width), int(output_data[0][i][2]*original_image_height))
            cv2.rectangle(frame_image, left_up_corner, right_down_corner, (0, 255, 0), thickness)
            cv2.putText(frame_image, label, left_up_corner_higher, cv2.FONT_HERSHEY_DUPLEX, fontsize, (255, 255, 255), thickness=thickness)
            cv2.putText(frame_image, label2, (30, 30), cv2.FONT_HERSHEY_DUPLEX, fontsize, (255, 255, 255), thickness=thickness)
    cv2.namedWindow('detect_result', cv2.WINDOW_NORMAL)
    # cv2.resizeWindow('detect_result', 800, 600)
    cv2.imshow("detect_result", frame_image)

    ## Use q to quit and space to pause
    key = cv2.waitKey(10) & 0xFF
    if key == ord("q"):
        break
    elif key == 32:
        cv2.waitKey(0)
        continue
cap.release()
cv2.destroyAllWindows()

But inference is still slow, because TFLite ops are optimized for mobile devices rather than desktop computers.
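If you are stuck on a desktop CPU, one knob worth trying is the interpreter's thread count. A minimal sketch, assuming a TensorFlow release whose tf.lite.Interpreter constructor accepts the num_threads argument (older releases do not have it):

import tensorflow as tf

Model_Path = "C:/MachineLearning/CV/uint8_dequantized_model_converted_from_exported_model.tflite"

# Run TFLite inference on several CPU threads. num_threads is only available
# in newer TensorFlow releases; on older ones this raises a TypeError.
interpreter = tf.lite.Interpreter(model_path=Model_Path, num_threads=4)
interpreter.allocate_tensors()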
