Мне нужно получить точные формы тензоров, чтобы посчитать FLOPs. У меня есть только файл .pb. Однако для тензоров внутри tf.while_loop я получаю «?» вместо размерностей. Пример приведён ниже:
Python 2.7,
tensorflow-gpu==1.14.0,
CUDA 10.0 и cuDNN 7:
import tensorflow as tf
from tensorflow import graph_util
from tensorflow import layers
def save():
    """Build a small tf.while_loop graph, run it once and freeze it to ``pad.pb``.

    The loop runs ``length`` (5) iterations. Each iteration:

    1. left-pads the current loop value by 4 along axis 1;
    2. applies a single-filter, width-5 ``valid`` conv1d;
    3. writes the result into a TensorArray.

    The TensorArray is stacked after the loop. Starting from ``tf.zeros([16, 30, 1])``,
    each iteration keeps axis 1 at 30 + 4 - 5 + 1 = 30.

    The graph is serialized with ``add_shapes=True`` so every node gets an
    ``_output_shapes`` attribute holding the static shape known at build time.
    Without it, a re-imported graph reports unknown shapes for the loop's
    Enter/Merge outputs, and therefore '?' for everything computed from them
    (e.g. ``while/Pad``).
    NOTE(review): this relies on the TF 1.14 importer applying
    ``_output_shapes`` to imported nodes — confirm on your TF version.
    """
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    # architecture
    inputs = tf.zeros([16, 30, 1])
    length = 5
    initial_outputs = tf.TensorArray(dtype=tf.float32, size=length)
    initial_t = tf.constant(0)

    def should_continue(t, *args):
        return t < length

    def iteration(t, inputs, outputs_):
        # 4 zeros on the left of axis 1 + 'valid' width-5 conv => same length out.
        pad_attention = tf.pad(inputs, [[0, 0], [4, 0], [0, 0]], 'CONSTANT')
        attention = layers.conv1d(pad_attention, 1, [5], padding='valid')
        # Prints the build-time Tensor, including whatever static shape is
        # known at graph-construction time.
        print(attention)
        outputs_ = outputs_.write(t, attention)
        return t + 1, attention, outputs_

    # Only the TensorArray result is needed after the loop.
    _, _, outputs = tf.while_loop(should_continue, iteration,
                                  [initial_t, inputs, initial_outputs])
    outputs = outputs.stack()

    with tf.Session(config=config) as sess:
        sess.run([tf.global_variables_initializer()])
        sess.run([outputs])
        # sess.graph_def has no `_output_shapes`; ask for them explicitly so
        # the static shapes survive serialization.
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        # Derive the output node name instead of hard-coding
        # 'TensorArrayStack/TensorArrayGatherV3'.
        const_graph = graph_util.convert_variables_to_constants(
            sess, graph_def, [outputs.op.name])

    # tf.gfile.FastGFile is deprecated in TF 1.x; GFile is the replacement.
    with tf.gfile.GFile("pad.pb", mode="wb") as f:
        f.write(const_graph.SerializeToString())
def load():
    """Import ``pad.pb`` into a fresh graph and print every tensor it contains.

    Each printed ``Tensor`` shows the static shape the importer could
    determine. Shapes of tensors inside a while loop can only be recovered
    if the .pb was written with per-node ``_output_shapes``. Otherwise the
    loop's Enter/Merge outputs come back with unknown shape and '?' spreads
    to the ops that consume them.
    """
    input_pb = 'pad.pb'
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():
        # tf.gfile.FastGFile is deprecated in TF 1.x; GFile is the replacement.
        with tf.gfile.GFile(input_pb, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    # Core-API walk over every op's outputs, used instead of the deprecated
    # tf.contrib.graph_editor.get_tensors. No Session is needed just to
    # inspect static shapes.
    for op in graph.get_operations():
        for tensor in op.outputs:
            print(tensor)
if __name__ == '__main__':
    # Write pad.pb first, then read it back and dump its tensors.
    for step in (save, load):
        step()
Вывод:
Tensor("zeros:0", shape=(16, 30, 1), dtype=float32)
Tensor("TensorArray/size:0", shape=(), dtype=int32)
Tensor("TensorArray:0", shape=(2,), dtype=resource)
Tensor("TensorArray:1", shape=(), dtype=float32)
Tensor("Const:0", shape=(), dtype=int32)
Tensor("while/Enter:0", dtype=int32)
Tensor("while/Enter_1:0", dtype=float32)
Tensor("while/Enter_2:0", dtype=float32)
Tensor("while/Merge:0", dtype=int32)
Tensor("while/Merge:1", shape=(), dtype=int32)
Tensor("while/Merge_1:0", dtype=float32)
Tensor("while/Merge_1:1", shape=(), dtype=int32)
Tensor("while/Merge_2:0", dtype=float32)
Tensor("while/Merge_2:1", shape=(), dtype=int32)
Tensor("while/Less/y:0", shape=(), dtype=int32)
Tensor("while/Less:0", dtype=bool)
Tensor("while/LoopCond:0", shape=(), dtype=bool)
Tensor("while/Switch:0", dtype=int32)
Tensor("while/Switch:1", dtype=int32)
Tensor("while/Switch_1:0", dtype=float32)
Tensor("while/Switch_1:1", dtype=float32)
Tensor("while/Switch_2:0", dtype=float32)
Tensor("while/Switch_2:1", dtype=float32)
Tensor("while/Identity:0", dtype=int32)
Tensor("while/Identity_1:0", dtype=float32)
Tensor("while/Identity_2:0", dtype=float32)
Tensor("while/Pad/paddings:0", shape=(3, 2), dtype=int32)
**Tensor("while/Pad:0", shape=(?, ?, ?), dtype=float32)** # Here I get '?', and '?' propagates to other tensors
Tensor("conv1d/kernel:0", shape=(5, 1, 1), dtype=float32)
Tensor("conv1d/kernel/read:0", shape=(5, 1, 1), dtype=float32)
Tensor("conv1d/bias:0", shape=(1,), dtype=float32)
Tensor("conv1d/bias/read:0", shape=(1,), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims/dim:0", shape=(), dtype=int32)
Tensor("while/conv1d/conv1d/ExpandDims:0", shape=(?, 1, ?, ?), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims_1/dim:0", shape=(), dtype=int32)
Tensor("while/conv1d/conv1d/ExpandDims_1/Enter:0", shape=(5, 1, 1), dtype=float32)
Tensor("while/conv1d/conv1d/ExpandDims_1:0", shape=(1, 5, 1, 1), dtype=float32)
Tensor("while/conv1d/conv1d:0", shape=(?, 1, ?, 1), dtype=float32)
Tensor("while/conv1d/conv1d/Squeeze:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/conv1d/BiasAdd/Enter:0", shape=(1,), dtype=float32)
Tensor("while/conv1d/BiasAdd:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/TensorArrayWrite/TensorArrayWriteV3/Enter:0", shape=(2,), dtype=resource)
Tensor("while/TensorArrayWrite/TensorArrayWriteV3:0", shape=(), dtype=float32)
Tensor("while/add/y:0", shape=(), dtype=int32)
Tensor("while/add:0", dtype=int32)
Tensor("while/NextIteration:0", dtype=int32)
Tensor("while/NextIteration_1:0", shape=(?, ?, 1), dtype=float32)
Tensor("while/NextIteration_2:0", shape=(), dtype=float32)
Tensor("while/Exit_2:0", dtype=float32)
Tensor("TensorArrayStack/TensorArraySizeV3:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range/start:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range/delta:0", shape=(), dtype=int32)
Tensor("TensorArrayStack/range:0", shape=(?,), dtype=int32)
Tensor("TensorArrayStack/TensorArrayGatherV3:0", dtype=float32)
Из-за чего у тензора Tensor("while/Pad:0", shape=(?, ?, ?), dtype=float32) неизвестная форма и как получить правильную форму?