Есть ли какой-нибудь способ загрузить файл .pb, созданный в tf v1, на tenorflow v2? - PullRequest
1 голос
/ 06 ноября 2019

Я пытаюсь загрузить файл .pb, который был создан в tf v1 на дистрибутиве tfv2, у меня вопрос: есть ли у версии 2 совместимость со старым pb?

Я уже пробовал несколько вещей, но никто из них не работал. Попытка загрузить pb-файл напрямую с помощью:

with tf.compat.v1.gfile.GFile("./saved_model.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

Результат, когда я запускаю приведенный выше код:

Traceback (most recent call last):
  File "read_tfv1_pb.py", line 7, in <module>
    graph_def.ParseFromString(f.read())
  File "D:\Anaconda3\envs\tf2\lib\site-packages\google\protobuf\message.py", line 187, in ParseFromString
    return self.MergeFromString(serialized)
  File "D:\Anaconda3\envs\tf2\lib\site-packages\google\protobuf\internal\python_message.py", line 1128, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "D:\Anaconda3\envs\tf2\lib\site-packages\google\protobuf\internal\python_message.py", line 1193, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "D:\Anaconda3\envs\tf2\lib\site-packages\google\protobuf\internal\decoder.py", line 968, in _SkipFixed32
    raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.

Если нет, есть ли способ, которым я могу сохранить весастарых pb и поместите их в новый экземпляр модели в tenorflow v2, чтобы применить обучение переносу / сохранению с новой структурой модели?

1 Ответ

0 голосов
/ 08 ноября 2019

Преобразовать его в tf.saved_model с кодом отсюда Преобразовать график прото (pb / pbtxt) в SavedModel для использования в TensorFlow Serving или Cloud ML Engine

Я простозаметил, что ваше .pb имя saved_model.pb, так что, возможно, это уже tf.saved_model. Если это так, вы можете загрузить его как

func = tf.saved_model.load('.').signatures["serving_default"] 
out = func( tf.constant(10,tf.float32) )
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...