Я уже некоторое время успешно использую TensorflowSharp с Faster RCNN;Тем не менее, я недавно изучил модель Retinanet, проверил, что она работает в python, и создал замороженный файл pb для использования с Tensorflow.Для FRCNN в репозитории TensorflowSharp GitHub есть пример, который показывает, как запустить / получить эту модель.Для Retinanet я попытался изменить код, но ничего не работает.У меня есть сводка моделей для Retinanet, с которой я пытался работать, но для меня не очевидно, что следует использовать.
Для FRCNN график запускается следующим образом:
var runner = m_session.GetRunner();
runner
.AddInput(m_graph["image_tensor"][0], tensor)
.Fetch(
m_graph["detection_boxes"][0],
m_graph["detection_scores"][0],
m_graph["detection_classes"][0],
m_graph["num_detections"][0]);
var output = runner.Run();
var boxes = (float[,,])output[0].GetValue(jagged: false);
var scores = (float[,])output[1].GetValue(jagged: false);
var classes = (float[,])output[2].GetValue(jagged: false);
var num = (float[])output[3].GetValue(jagged: false);
Из сводки модели для FRCNN очевидно, что представляют собой входные данные ("image_tensor") и выходные данные ("detection_boxes", "detection_scores", "treatment_classes" и "num_detections").Они не одинаковы для Retinanet (я пробовал), и я не могу понять, какими они должны быть.Часть «Fetch» в приведенном выше коде вызывает сбой, и я предполагаю, что это потому, что я неправильно понимаю имена узлов.
Я не буду вставлять всю сводку Retinanet здесь, но здесьэто первые несколько узлов:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, None, None, 3 0
__________________________________________________________________________________________________
padding_conv1 (ZeroPadding2D) (None, None, None, 3 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, None, None, 6 9408 padding_conv1[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, None, None, 6 256 conv1[0][0]
__________________________________________________________________________________________________
conv1_relu (Activation) (None, None, None, 6 0 bn_conv1[0][0]
__________________________________________________________________________________________________
И вот несколько последних узлов:
__________________________________________________________________________________________________
anchors_0 (Anchors) (None, None, 4) 0 P3[0][0]
__________________________________________________________________________________________________
anchors_1 (Anchors) (None, None, 4) 0 P4[0][0]
__________________________________________________________________________________________________
anchors_2 (Anchors) (None, None, 4) 0 P5[0][0]
__________________________________________________________________________________________________
anchors_3 (Anchors) (None, None, 4) 0 P6[0][0]
__________________________________________________________________________________________________
anchors_4 (Anchors) (None, None, 4) 0 P7[0][0]
__________________________________________________________________________________________________
regression_submodel (Model) (None, None, 4) 2443300 P3[0][0]
P4[0][0]
P5[0][0]
P6[0][0]
P7[0][0]
__________________________________________________________________________________________________
anchors (Concatenate) (None, None, 4) 0 anchors_0[0][0]
anchors_1[0][0]
anchors_2[0][0]
anchors_3[0][0]
anchors_4[0][0]
__________________________________________________________________________________________________
regression (Concatenate) (None, None, 4) 0 regression_submodel[1][0]
regression_submodel[2][0]
regression_submodel[3][0]
regression_submodel[4][0]
regression_submodel[5][0]
__________________________________________________________________________________________________
boxes (RegressBoxes) (None, None, 4) 0 anchors[0][0]
regression[0][0]
__________________________________________________________________________________________________
classification_submodel (Model) (None, None, 1) 2381065 P3[0][0]
P4[0][0]
P5[0][0]
P6[0][0]
P7[0][0]
__________________________________________________________________________________________________
clipped_boxes (ClipBoxes) (None, None, 4) 0 input_1[0][0]
boxes[0][0]
__________________________________________________________________________________________________
classification (Concatenate) (None, None, 1) 0 classification_submodel[1][0]
classification_submodel[2][0]
classification_submodel[3][0]
classification_submodel[4][0]
classification_submodel[5][0]
__________________________________________________________________________________________________
filtered_detections (FilterDete [(None, 300, 4), (No 0 clipped_boxes[0][0]
classification[0][0]
==================================================================================================
Total params: 36,382,957
Trainable params: 36,276,717
Non-trainable params: 106,240
Любая помощь с выяснением, как исправить часть "Fetch", будет оченьприветствуется.
РЕДАКТИРОВАТЬ:
Чтобы углубиться в это, я нашел функцию python для печати имен операций из файла .pb.Делая это для файла FRPNN .pb, он четко дал имена выходных узлов, как видно ниже (только последние несколько строк из выходных данных функции python).
import/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayStack_4/TensorArrayGatherV3
import/SecondStagePostprocessor/ToFloat_1
import/add/y
import/add
import/detection_boxes
import/detection_scores
import/detection_classes
import/num_detections
Если япроделайте то же самое с файлом Retinanet .pb, но не совсем понятно, что это за выходные данные.Вот несколько последних строк из функции python.
import/filtered_detections/map/while/NextIteration_4
import/filtered_detections/map/while/Exit_2
import/filtered_detections/map/while/Exit_3
import/filtered_detections/map/while/Exit_4
import/filtered_detections/map/TensorArrayStack/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack/range/start
import/filtered_detections/map/TensorArrayStack/range/delta
import/filtered_detections/map/TensorArrayStack/range
import/filtered_detections/map/TensorArrayStack/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_1/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_1/range/start
import/filtered_detections/map/TensorArrayStack_1/range/delta
import/filtered_detections/map/TensorArrayStack_1/range
import/filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_2/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_2/range/start
import/filtered_detections/map/TensorArrayStack_2/range/delta
import/filtered_detections/map/TensorArrayStack_2/range
import/filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3
Для справки вот функция Python, которую я использовал:
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
Надеюсь, это поможет.