Я нашел, что проще всего использовать tf.variable_scope
с reuse=tf.AUTO_REUSE
.tf.name_scope
является необязательным, но сохраняет ваши графики чистыми для визуализации тензорной доски.
import tensorflow as tf
def get_logits(image):
with tf.variable_scope('my_network', reuse=tf.AUTO_REUSE):
# more complex network probably
x = image
x = tf.layers.conv2d(x, 3, 1, activation=tf.nn.relu)
x = tf.layers.conv2d(x, 3, 1, activation=tf.nn.relu)
x = tf.layers.flatten(x)
x = tf.layers.dense(x, 10)
return x
batch_size = 2
height = 6
width = 6
# dummy images
image1 = tf.zeros((batch_size, height, width, 3), dtype=tf.float32)
image2 = tf.zeros((batch_size, height, width, 3), dtype=tf.float32)
with tf.name_scope('instance1'):
out1 = get_logits(image1)
print(len(tf.global_variables())) # 6
with tf.name_scope('instance2'):
out2 = get_logits(image2)
print(len(tf.global_variables())) # still 6
Я не уверен в вашей точной проблеме с различными объектами.Если у вас есть несколько разных объектов, просто убедитесь, что они вызывают одну и ту же функцию.
class MyNetwork(object):
def __init__(self, name):
self.name = name
def get_network_logits(self, image):
with tf.name_scope(self.name):
return get_logits(image)
n1 = MyNetwork('instance1')
n2 = MyNetwork('instance2')
l1 = n1(image1)
l2 = n2(image2)