в тензорном потоке, как инициализировать часть переменных - PullRequest
0 голосов
/ 04 июля 2018

Я хочу инициализировать список переменных, и я определил их как список с именем 'block_var'. Я хочу использовать усеченный нормальный метод для их инициализации.

block_var = [v for v in tf.global_variables() if 'block' in v.name]
init_block = tf.variables_initializer(var_list = block_var)

тогда что мне делать? Я пытался

for v in block_var:
    v.initializer = tf.truncated_normal_initializer()

Я тоже пробовал

init_block = tf.truncated_normal_initializer()

оба поля.

Ответы [ 2 ]

0 голосов
/ 04 июля 2018

Обновление: как уже упоминалось в комментариях, в моем предыдущем коде были ошибки. Это создало новые переменные. Поэтому я сделал улучшения и теперь использую assert для проверки.

Это может быть так.

import tensorflow as tf
import numpy as np

with tf.variable_scope("reuse"):
    x = tf.get_variable('x', [5, 5])
    y = tf.get_variable('y', [5, 5])

block_var = [v.name for v in tf.trainable_variables()]

def initialize( name, shape ):
    with tf.variable_scope("reuse",reuse=True):
        x = tf.get_variable(name.split(':')[0][-1],  shape = shape, initializer=tf.random_normal_initializer())
        x.initializer.run()
        print (x.eval())


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    vars = sess.run(block_var)
    for name,shape in zip(block_var,vars):
        initialize( name, shape.shape )
    assert (len(tf.global_variables()) == 2), "Variables are not reused"
0 голосов
/ 04 июля 2018

Раствор 1

Вам нужно передать аргумент инициализатора для get_variable, как

import tensorflow as tf
import numpy as np

ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=40))

ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))

block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]

with tf.Session() as sess:
    sess.run(tf.variables_initializer(var_list=block_vars))
    print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')

    try:
        print(np.mean(sess.run([ref4])))
    except Exception as e:
        print('[INFO] failed as expected with message %s' % e)

    sess.run(tf.variables_initializer(var_list=block_vars_complement))
    print(np.mean(sess.run([ref4, ref5])), 'should be ~ 10')

Раствор 2

Если вы не хотите передавать initializer каждому get_variable, вы можете использовать пользовательский метод получения, такой как

import tensorflow as tf
import numpy as np


def my_getter(getter, name, shape, *args, **kwargs):
    if 'block' not in name:
        return getter(name=name, shape=shape, *args, **kwargs)
    else:
        kwargs['initializer'] = tf.truncated_normal_initializer(mean=40)
        return getter(name=name, shape=shape, *args, **kwargs)

with tf.variable_scope("some_scopename", custom_getter=my_getter):
    ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=10))
    ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=10))
    ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=10))

    ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
    ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))

block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]

with tf.Session() as sess:
    sess.run(tf.variables_initializer(var_list=block_vars))
    print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')

    try:
        print(np.mean(sess.run([ref4])))
    except Exception as e:
        print('[INFO] failed as expected with message %s' % e)

    sess.run(tf.variables_initializer(var_list=block_vars_complement))
    print(np.mean(sess.run([ref4, ref5])), 'should be ~ 10')

Решение 3

tf.truncated_normal_initializer или другие инициализаторы - это просто самостоятельные операции. Следовательно, они могут быть применены в цикле ко всем переменным из коллекции, и такое групповое обновление может быть наконец применено (см. initialize_collection):

import tensorflow as tf
import numpy as np

ref0 = tf.get_variable('block0', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref1 = tf.get_variable('block1', [2], initializer=tf.truncated_normal_initializer(mean=40))
ref2 = tf.get_variable('block2', [2], initializer=tf.truncated_normal_initializer(mean=40))

ref4 = tf.get_variable('foo0', [2], initializer=tf.truncated_normal_initializer(mean=10))
ref5 = tf.get_variable('foo1', [2], initializer=tf.truncated_normal_initializer(mean=10))

block_vars = [v for v in tf.global_variables() if 'block' in v.name]
block_vars_complement = [v for v in tf.global_variables() if 'block' not in v.name]


def initialize_collection(collection, initializer):
    ops = []
    for v in collection:
        ops.append(v.assign(initializer(shape=v.shape)))
    return tf.group(ops)


with tf.Session() as sess:
    sess.run(tf.variables_initializer(var_list=block_vars))
    print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ 40')

    sess.run(initialize_collection(block_vars, tf.truncated_normal_initializer(mean=-40, stddev=0.01)))
    print(np.mean(sess.run([ref0, ref1, ref2])), 'should be ~ -40')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...