Имя переменной не работает для init_from_checkpoint? - PullRequest
0 голосов
/ 31 декабря 2018

Пожалуйста, смотрите эту игрушечную модель:

import tensorflow as tf
import os

if not os.path.isdir('./temp'):
    os.mkdir('./temp')


def create_and_save_varialbe(sess=tf.Session()):
    a = tf.get_variable("a", [])
    saver_a = tf.train.Saver({"a": a})
    init = tf.global_variables_initializer()
    sess.run(init)
    saver_a.save(sess, './temp/temp_model')
    a = sess.run(a)
    print('the initialized a is %f' % a)
    return a


def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b


def init_get_variable(sess=tf.Session()):
    c = tf.get_variable("c", shape=[])
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'c'})
    init = tf.global_variables_initializer()
    sess.run(init)
    c = sess.run(c)
    print(c)
    return c


a = create_and_save_varialbe()
b = init_variable()
c = init_get_variable()

Функция init_get_varialbe работает, но не функция init_variable.

ValueError: Карта назначений с именем только области действия должна соответствовать только области действия a.Должно быть 'scope /': 'other_scope /'.

Почему имя переменной, определенной переменной, не работает в этом сценарии, и как я могу ее решить?

Версия Tensorflow: 1.12

1 Ответ

0 голосов
/ 31 декабря 2018

Это из-за разницы между переменной и get_variable.

Существует два способа решения этой проблемы:

1) введите переменную, отличную от ее имени.

def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': b})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

Поскольку, если это переменная, тензор потока может получить его напрямую :

if _is_variable(current_var_or_name) or (
    isinstance(current_var_or_name, list)
    and all(_is_variable(v) for v in current_var_or_name)):
  var = current_var_or_name

В противном случае он должен получить переменную из хранилища переменных :

  store_vars = vs._get_default_variable_store()._vars 

Но переменная, определенная переменной, отсутствует в коллекции ('varstore_key',), как объясняется в этом ответе .

Затем 2) вы можете добавить его в коллекцию самостоятельно:

from tensorflow.python.ops.variable_scope import _VariableStore
from tensorflow.python.framework import ops
def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    store = _VariableStore()
    store._vars = {'b': b}
    ops.add_to_collection(('__variable_store',), store)
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

Обе работы.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...