Tensorflow: создайте tf.NodeDef () и установите атрибуты - PullRequest
7 голосов
/ 27 мая 2019

Я пытаюсь создать новый узел и установить его атрибуты.

Например, при печати одного из узлов графа я вижу, что его атрибуты:

attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

Я могу создатьтакой узел, как:

node = tf.NodeDef(name='MyConstTensor', op='Const',
                   attr={'value': tf.AttrValue(tensor=tensor_proto),
                         'dtype': tf.AttrValue(type=dt)})

но как добавить key: "T" атрибут?то есть что должно быть внутри tf.AttrValue в этом случае?

Глядя на attr_value.proto Я пробовал:

node = tf.NodeDef()
node.name = 'MySub'
node.op = 'Sub'
node.input.extend(['MyConstTensor', 'conv2'])
node.attr["key"].s = 'T' # TypeError: 'T' has type str, but expected one of: bytes

ОБНОВЛЕНИЕ:

Я понял, что в Tensorflow это должно быть записано так:

node.attr["T"].type = b'float32'

Но это выдает ошибку:

TypeError: b'float32 'имеет типбайт, но ожидается один из: int, long

И я не уверен, какое значение int соответствует float32.

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L23

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto#L35

Ответы [ 2 ]

2 голосов
/ 01 июня 2019

Методом проб и ошибок я выдыхаю, что это просто:

node.attr["T"].type = 1 # to set type to float32
1 голос
/ 31 мая 2019

Попробуйте передать T как байт:

node.attr["key"].s = b'T'

Если вы хотите передать больше символов, попробуйте класс bytearray.

В определении protobuf для AttrValue s определяется как байты, а не строка. В руководстве по Protobuf говорится, что это должна быть строка в python, но ваша ошибка говорит о том, что она больше похожа на байтовый массив.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...