Как перечислить все используемые операции в Tensorflow SavedModel? - PullRequest
10 голосов
/ 10 февраля 2020

Если я сохраню свою модель с помощью функции tensorflow.saved_model.save в формате SavedModel, как я могу узнать, какие операции Tensorflow используются в этой модели впоследствии. Поскольку модель может быть восстановлена, эти операции хранятся в графике, я думаю, в файле saved_model.pb. Если я загружаю этот protobuf (то есть не всю модель), библиотечная часть protobuf перечисляет их, но это пока не задокументировано и не помечено как экспериментальная функция. Модели, созданные в Tensorflow 1.x, не будут иметь этой части.

Так что же является быстрым и надежным способом получения списка используемых операций (например, MatchingFiles или WriteFile) из модели в формате SavedModel?

Прямо сейчас я могу заморозить все это, как это делает tensorflowjs-converter. Как они также проверяют на поддерживаемые операции. В настоящее время это не работает, когда в модели присутствует LSTM, см. здесь . Есть ли лучший способ сделать это, поскольку Ops определенно там?

Пример модели:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Ожидается при выводе всех Ops, содержащих в этом случае, по крайней мере:

1 Ответ

1 голос
/ 14 февраля 2020

Если saved_model.pb - это сообщение SavedModel protobuf, то вы получаете операции непосредственно оттуда. Допустим, мы создаем модель следующим образом:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Теперь мы можем найти операции, используемые этой моделью, следующим образом:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...