Ошибка чтения замороженного вывода в Tensorflow 2.0 - PullRequest
0 голосов
/ 01 апреля 2020

Я пытаюсь прочитать определения графиков и параметры обучения модели тензорного потока, которую я скачал из сада тензорного потока. Цель состоит в том, чтобы иметь возможность найти параметры и использовать их для преобразования модели в модель tflite. Ниже приведен код, который я использую для чтения определений графиков, которые я скопировал онлайн с некоторыми небольшими изменениями:

import tensorflow as tf
import argparse

parse = argparse.ArgumentParser()
parse.add_argument('-m','--model',required=True,help='Absolute path to graph file')
args=vars(parse.parse_args())

def load_graph(frozen_graph_filename):
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef
        graph_def.ParseFromString(f.read())

    with tf.compat.v1.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph


if __name__ == '__main__':
    info = args['model']
    graph = load_graph(info)
    for op in graph.get_operations():
        abc = graph.get_tensor_by_name(op.name + ":0")
        print(abc)

После выполнения этого кода я получаю эту ошибку -

Traceback (most recent call last):
  File "../info.py", line 20, in <module>
    graph = load_graph(info)
  File "../info.py", line 11, in load_graph
    graph_def.ParseFromString(f.read())
TypeError: descriptor 'ParseFromString' requires a 'google.protobuf.pyext._message.CMessage' object but received a 'bytes'

Что может быть проблемой. Я проверил в Интернете другие источники, и этот код, кажется, работает все время, поэтому я предполагаю, что что-то сделал не так.

...