Общие веса для подкласса в сиамской модели в Tensorflow - PullRequest
0 голосов
/ 23 мая 2018

У меня есть некоторые проблемы с организацией моего кода в TENSORFLOW.Я хочу реализовать сиамскую модель, которая сравнивает выходные данные двух сверточных сетей, которые имеют одинаковые веса.

Я хочу создать класс для определения моей сверточной сети и другой класс для определения моей глобальной модели.Кажется, что есть несколько способов разделить веса (ленивая загрузка, использовать много областей, ...), но как я могу сделать это между многими объектами?

Полезны ли флаги в моем случае?

Любая помощь будет полезна

1 Ответ

0 голосов
/ 23 мая 2018

Я нашел, что проще всего использовать tf.variable_scope с reuse=tf.AUTO_REUSE.tf.name_scope является необязательным, но сохраняет ваши графики чистыми для визуализации тензорной доски.

import tensorflow as tf

def get_logits(image):
  with tf.variable_scope('my_network', reuse=tf.AUTO_REUSE):
    # more complex network probably
    x = image
    x = tf.layers.conv2d(x, 3, 1, activation=tf.nn.relu)
    x = tf.layers.conv2d(x, 3, 1, activation=tf.nn.relu)
    x = tf.layers.flatten(x)
    x = tf.layers.dense(x, 10)
    return x


batch_size = 2
height = 6
width = 6

# dummy images
image1 = tf.zeros((batch_size, height, width, 3), dtype=tf.float32)
image2 = tf.zeros((batch_size, height, width, 3), dtype=tf.float32)

with tf.name_scope('instance1'):
  out1 = get_logits(image1)

print(len(tf.global_variables()))  # 6

with tf.name_scope('instance2'):
  out2 = get_logits(image2)

print(len(tf.global_variables()))  # still 6

Я не уверен в вашей точной проблеме с различными объектами.Если у вас есть несколько разных объектов, просто убедитесь, что они вызывают одну и ту же функцию.

class MyNetwork(object):
  def __init__(self, name):
    self.name = name

  def get_network_logits(self, image):
    with tf.name_scope(self.name):
      return get_logits(image)

n1 = MyNetwork('instance1')
n2 = MyNetwork('instance2')

l1 = n1(image1)
l2 = n2(image2)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...