Я пытался прочитать модель tflite и вытащить все параметры слоев.
Мои шаги:
- Я сгенерировал представление модели flatbuffers, запустив (пожалуйста, создайте flatc раньше):
flatc -python tensorflow/tensorflow/lite/schema/schema.fbs
Результатом является папка tflite/
, которая содержит файлы описания слоя (*.py
) и некоторые утилитарные файлы.
- Я успешно загрузил модель:
в случае ошибки импорта: установите PYTHONPATH так, чтобы он указывал на папку, в которой tflite / равен
from tflite.Model import Model
def read_tflite_model(file):
buf = open(file, "rb").read()
buf = bytearray(buf)
model = Model.GetRootAsModel(buf, 0)
return model
- Я частично извлек параметры модели и узла и сложил их в итерации по узлам:
Модельная часть:
def print_model_info(model):
version = model.Version()
print("Model version:", version)
description = model.Description().decode('utf-8')
print("Description:", description)
subgraph_len = model.SubgraphsLength()
print("Subgraph length:", subgraph_len)
Узловая часть:
def print_nodes_info(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
operators_len = subgraph.OperatorsLength()
print('Operators length:', operators_len)
from collections import deque
nodes = deque(subgraph.InputsAsNumpy())
STEP_N = 0
MAX_STEPS = operators_len
print("Nodes info:")
while len(nodes) != 0 and STEP_N <= MAX_STEPS:
print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
print("-" * 60)
node_id = nodes.pop()
print("Node id:", node_id)
tensor = subgraph.Tensors(node_id)
print("Node name:", tensor.Name().decode('utf-8'))
print("Node shape:", tensor.ShapeAsNumpy())
# which type is it? what does it mean?
type_of_tensor = tensor.Type()
print("Tensor type:", type_of_tensor)
quantization = tensor.Quantization()
min = quantization.MinAsNumpy()
max = quantization.MaxAsNumpy()
scale = quantization.ScaleAsNumpy()
zero_point = quantization.ZeroPointAsNumpy()
print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))
# I do not understand it again. what is j, that I set to 0 here?
operator = subgraph.Operators(0)
for i in operator.OutputsAsNumpy():
nodes.appendleft(i)
STEP_N += 1
print("-"*60)
Пожалуйста, укажите мне на документацию или пример использования этого API.
Мои проблемы:
Я не могу получить документацию по этому API
Итерации по объектам Tensor для меня невозможны, так как в них нет методов Inputs и Outputs. + subgraph.Operators(j=0)
Я не понимаю, что здесь означает j. Из-за этого мой цикл проходит через два узла: ввод (один раз) и следующий снова и снова.
Итерации по объектам Operator безусловно возможны:
Здесь мы перебираем их все, но я не могу понять, как отобразить Operator и Tensor.
def print_in_out_info_of_all_operators(model):
# what does this 0 mean? should it always be zero?
subgraph = model.Subgraphs(0)
for i in range(subgraph.OperatorsLength()):
operator = subgraph.Operators(i)
print('Outputs', operator.OutputsAsNumpy())
print('Inputs', operator.InputsAsNumpy())
Я не понимаю, как извлечь параметры из объекта Operator. Метод BuiltinOptions дает мне объект Table, который я не знаю, на что отображать.
subgraph = model.Subgraphs(0)
Что означает этот 0? должен ли он всегда быть нулевым? очевидно нет, но что это? Идентификатор подграфа? Если так - я счастлив. Если нет, пожалуйста, попробуйте объяснить это.