Введение: я новичок в машинном обучении, и мне и моему коллеге приходится реализовывать алгоритм обнаружения светофоров.Я скачал предварительно обученную модель (быстрее rcnn) и провел несколько тренировочных шагов (~ 10000).Теперь при использовании алгоритма обнаружения объектов из репозитория tenorflow git обнаружено несколько светофоров в одной области.
Я провел небольшое исследование и обнаружил функцию "tf.image.non_max_suppression", но не могу заставить ее работатькак задумано (честно говоря, я даже не могу заставить его работать).
Я предполагаю, что вы знаете пример кода обнаружения объектов tf, поэтому вы также знаете, что все поля возвращаются с использованием словаря (output_dict).
Чтобы "очистить" ящики, которые я использую:
selected_indices = tf.image.non_max_suppression(
boxes = output_dict['detection_boxes'],
scores = output_dict['detection_scores'],
max_output_size = 1,
iou_threshold = 0.5,
score_threshold = float('-inf'),
name = None)
Сначала я подумал, что могу использовать selected_indices в качестве нового списка ящиков, поэтому я попробовал это:
vis_util.visualize_boxes_and_labels_on_image_array(
image = image_np,
boxes = selected_indices,
classes = output_dict['detection_classes'],
scores = output_dict['detection_scores'],
category_index = category_index,
instance_masks = output_dict.get('detection_masks'),
use_normalized_coordinates = True)
но когда я заметил, что это не работает, я нашел необходимый метод: "tf.gather ()".Затем я запустил следующий код:
boxes = output_dict['detection_boxes']
selected_indices = tf.image.non_max_suppression(
boxes = boxes,
scores = output_dict['detection_scores'],
max_output_size = 1,
iou_threshold = 0.5,
score_threshold = float('-inf'),
name = None)
selected_boxes = tf.gather(boxes, selected_indices)
vis_util.visualize_boxes_and_labels_on_image_array(
image = image_np,
boxes = selected_boxes,
classes = output_dict['detection_classes'],
scores = output_dict['detection_scores'],
category_index = category_index,
instance_masks = output_dict.get('detection_masks'),
use_normalized_coordinates = True)
, но даже этот не работает.Я получаю AttributeError (у объекта 'Tensor' нет атрибута 'tolist') в visualization_utils.py в строке 689.