Как получить форму тензоров в while_l oop map_fn в файлах тензорного потока pb? - PullRequest
0 голосов
/ 15 января 2020

Мне нужно получить точную форму тензоров для расчета флопов. У меня есть только файл pb. Тем не менее, я получаю «?» в тф.в_д_1 oop. Пример показан ниже:

python -2,7,

tf-gpu == 1.14.0,

cuda-10.0 и cudnn-7:

import tensorflow as tf
from tensorflow import graph_util
from tensorflow import layers

def save():

    # define the architecture and save the graph into a pb file

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    # architecture
    inputs = tf.zeros([16, 30, 1])
    length = 5
    initial_outputs = tf.TensorArray(dtype=tf.float32, size=length)
    initial_t = tf.constant(0)

    def should_continue(t, *args):
        return t < length

    def iteration(t, inputs, outputs_):
        pad_attention = tf.pad(inputs, [[0, 0], [4, 0], [0, 0]], 'CONSTANT')
        attention = layers.conv1d(pad_attention, 1, [5], padding='valid')
        # I can get the shape while using 'sess.run'
        print(attention)
        outputs_ = outputs_.write(t, attention)
        return t + 1, attention, outputs_

    t, attention, outputs = tf.while_loop(should_continue, iteration,
                                  [initial_t, inputs, initial_outputs])
    outputs = outputs.stack()

    with tf.Session(config=config) as sess:
        init = tf.global_variables_initializer()
        sess.run([init])
        sess.run([outputs])

        # save the pb file
        const_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,['TensorArrayStack/TensorArrayGatherV3'])
        with tf.gfile.FastGFile("pad.pb",mode="wb") as f:
            f.write(const_graph.SerializeToString())

def load():

    # load the pb file and print the tensor information

    input_pb = 'pad.pb'
    tf.reset_default_graph()
    g1 = tf.Graph()
    with g1.as_default():
        with tf.gfile.FastGFile(input_pb, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name='')

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as session:
            for tensor in tf.contrib.graph_editor.get_tensors(tf.get_default_graph()):
                # I can't get the exact shape here
                print(tensor)

if __name__ == '__main__':
    save()
    load()

Вывод:

Tensor("zeros:0", shape=(16, 30, 1), dtype=float32)
Tensor("TensorArray/size:0", shape=(), dtype=int32)
Tensor("TensorArray:0", shape=(2,), dtype=resource)
Tensor("TensorArray:1", shape=(), dtype=float32)
Tensor("Const:0", shape=(), dtype=int32)
Tensor("while/Enter:0", dtype=int32)
Tensor("while/Enter_1:0", dtype=float32)
Tensor("while/Enter_2:0", dtype=float32)
Tensor("while/Merge:0", dtype=int32)
Tensor("while/Merge:1", shape=(), dtype=int32)
Tensor("while/Merge_1:0", dtype=float32)
Tensor("while/Merge_1:1", shape=(), dtype=int32)
Tensor("while/Merge_2:0", dtype=float32)
Tensor("while/Merge_2:1", shape=(), dtype=int32)
Tensor("while/Less/y:0", shape=(), dtype=int32)
Tensor("while/Less:0", dtype=bool)
Tensor("while/LoopCond:0", shape=(), dtype=bool)
Tensor("while/Switch:0", dtype=int32)
Tensor("while/Switch:1", dtype=int32)
Tensor("while/Switch_1:0", dtype=float32)
Tensor("while/Switch_1:1", dtype=float32)
Tensor("while/Switch_2:0", dtype=float32)
Tensor("while/Switch_2:1", dtype=float32)
Tensor("while/Identity:0", dtype=int32)
Tensor("while/Identity_1:0", dtype=float32)
Tensor("while/Identity_2:0", dtype=float32)
Tensor("while/Pad/paddings:0", shape=(3, 2), dtype=int32)
**Tensor("while/Pad:0", shape=(?, ?, ?), dtype=float32)**     # Here I get '?', and '?' propagates to other tensors
Tensor("conv1d/kernel:0", shape=(5, 1, 1), dtype=float32)
Tensor("conv1d/kernel/read:0", shape=(5, 1, 1), dtype=float32)
Tensor("conv1d/bias:0", shape=(1,), dtype=float32)
Tensor("conv1d/bias/read:0", shape=(1,), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims/dim:0", shape=(), dtype=int32)
Tensor("while/conv1d/conv1d/ExpandDims:0", shape=(?, 1, ?, ?), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims_1/dim:0", shape=(), dtype=int32)
Tensor("while/conv1d/conv1d/ExpandDims_1/Enter:0", shape=(5, 1, 1), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims_1:0", shape=(1, 5, 1, 1), dtype=float32)
Tensor("while/conv1d/conv1d:0", shape=(?, 1, ?, 1), dtype=float32)
Tensor("while/conv1d/conv1d/Squeeze:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/conv1d/BiasAdd/Enter:0", shape=(1,), dtype=float32)
Tensor("while/conv1d/BiasAdd:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/TensorArrayWrite/TensorArrayWriteV3/Enter:0", shape=(2,), dtype=resource)
Tensor("while/TensorArrayWrite/TensorArrayWriteV3:0", shape=(), dtype=float32)
Tensor("while/add/y:0", shape=(), dtype=int32)
Tensor("while/add:0", dtype=int32)
Tensor("while/NextIteration:0", dtype=int32)
Tensor("while/NextIteration_1:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/NextIteration_2:0", shape=(), dtype=float32)
Tensor("while/Exit_2:0", dtype=float32)
Tensor("TensorArrayStack/TensorArraySizeV3:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range/start:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range/delta:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range:0", shape=(?,), dtype=int32)
Tensor("TensorArrayStack/TensorArrayGatherV3:0", dtype=float32)

Что вызвало Тензор ("while / Pad: 0", shape = (?,?,?), Dtype = float32) а как мне получить правильную форму?

...