Как читать параметры слоев модели .tflite в python - PullRequest
1 голос
/ 21 марта 2019

Я пытался прочитать модель tflite и вытащить все параметры слоев.

Мои шаги:

  1. Я сгенерировал представление модели flatbuffers, запустив (пожалуйста, создайте flatc раньше):

flatc -python tensorflow/tensorflow/lite/schema/schema.fbs

Результатом является папка tflite/, которая содержит файлы описания слоя (*.py) и некоторые утилитарные файлы.

  1. Я успешно загрузил модель:

в случае ошибки импорта: установите PYTHONPATH так, чтобы он указывал на папку, в которой tflite / равен

from tflite.Model import Model
def read_tflite_model(file):
    buf = open(file, "rb").read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    return model
  1. Я частично извлек параметры модели и узла и сложил их в итерации по узлам:

Модельная часть:

def print_model_info(model):
        version = model.Version()
        print("Model version:", version)
        description = model.Description().decode('utf-8')
        print("Description:", description)
        subgraph_len = model.SubgraphsLength()
        print("Subgraph length:", subgraph_len)

Узловая часть:

def print_nodes_info(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    operators_len = subgraph.OperatorsLength()
    print('Operators length:', operators_len)

    from collections import deque
    nodes = deque(subgraph.InputsAsNumpy())

    STEP_N = 0
    MAX_STEPS = operators_len
    print("Nodes info:")
    while len(nodes) != 0 and STEP_N <= MAX_STEPS:
        print("MAX_STEPS={} STEP_N={}".format(MAX_STEPS, STEP_N))
        print("-" * 60)

        node_id = nodes.pop()
        print("Node id:", node_id)

        tensor = subgraph.Tensors(node_id)
        print("Node name:", tensor.Name().decode('utf-8'))
        print("Node shape:", tensor.ShapeAsNumpy())

        # which type is it? what does it mean?
        type_of_tensor = tensor.Type()
        print("Tensor type:", type_of_tensor)

        quantization = tensor.Quantization()
        min = quantization.MinAsNumpy()
        max = quantization.MaxAsNumpy()
        scale = quantization.ScaleAsNumpy()
        zero_point = quantization.ZeroPointAsNumpy()
        print("Quantization: ({}, {}), s={}, z={}".format(min, max, scale, zero_point))

        # I do not understand it again. what is j, that I set to 0 here?
        operator = subgraph.Operators(0)
        for i in operator.OutputsAsNumpy():
            nodes.appendleft(i)

        STEP_N += 1

    print("-"*60)

Пожалуйста, укажите мне на документацию или пример использования этого API.

Мои проблемы:

  1. Я не могу получить документацию по этому API

  2. Итерации по объектам Tensor для меня невозможны, так как в них нет методов Inputs и Outputs. + subgraph.Operators(j=0) Я не понимаю, что здесь означает j. Из-за этого мой цикл проходит через два узла: ввод (один раз) и следующий снова и снова.

  3. Итерации по объектам Operator безусловно возможны:

Здесь мы перебираем их все, но я не могу понять, как отобразить Operator и Tensor.

def print_in_out_info_of_all_operators(model):
    # what does this 0 mean? should it always be zero?
    subgraph = model.Subgraphs(0)
    for i in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(i)
        print('Outputs', operator.OutputsAsNumpy())
        print('Inputs', operator.InputsAsNumpy())
  1. Я не понимаю, как извлечь параметры из объекта Operator. Метод BuiltinOptions дает мне объект Table, который я не знаю, на что отображать.

  2. subgraph = model.Subgraphs(0) Что означает этот 0? должен ли он всегда быть нулевым? очевидно нет, но что это? Идентификатор подграфа? Если так - я счастлив. Если нет, пожалуйста, попробуйте объяснить это.

...