TensorFlow import_meta_graph для нескольких графиков с дубликатами / конфликтами имен - PullRequest
0 голосов
/ 19 июня 2019

Отказ от ответственности: Этот вопрос является продолжением предыдущего вопроса .

Можно ли импортировать несколько графиков TensorFlow с переменными с одинаковыми именами?Насколько я понимаю, по умолчанию существующие переменные будут перезаписаны на tf.train.import_meta_graph().Пример в ответе на другой вопрос показывает пример того, как это сделать с переменными с разными именами:

import tensorflow as tf

# The variables v1 and v2 that we want to restore
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")

# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
    saver1.restore(sess, "tmp/v1.ckpt")
    saver2.restore(sess, "tmp/v2.ckpt")
    print sess.run(v1)
    print sess.run(v2)

Переменные v1 и v2 ранее были сохранены из разныхграфики и теперь оба доступны в графике по умолчанию TensorFlow.

Однако при работе с tensor_forest.RandomForestGraphs() переменные имеют фиксированные имена (например, device_dummy_1).При попытке импортировать несколько таких графиков в какой-то момент возникают ошибки:

NotFoundError: Key device_dummy_100 not found in checkpoint
     [[{{node save/RestoreV2}}]]
     [[{{node GroupCrossDeviceControlEdges_0/save/restore_all}}]]

Насколько я понимаю: у меня есть несколько RF (здесь: 3), у всех из которых есть 130 деревьев, и еще один RF только с 100 деревьями,Когда импортируется последний (меньший), деревья с 101 по 130 не обнаруживаются в импортированном файле, и импортер жалуется на эти отсутствующие переменные.Следовательно, я должен предположить, что импорт второго RF перезаписывает предыдущий.Это правильно?

В итоге у меня есть следующие проблемы:

  • tensor_forest.RandomForestGraphs() не позволяет, например, префикс внутренних имен переменных - разные RF имеют одинаковые имена переменных
  • Импорт графиков с одинаковыми именами переменных перезаписывает существующие переменные

Есть ли способ изменить (префикс) всех имен переменных в одном RF до экспорта или во время импорта?Или есть какое-то другое решение для этого?

РЕДАКТИРОВАТЬ: Хотя использование переменной области кажется многообещающим, я создал этот минимальный пример, чтобы продемонстрировать проблему, с которой я все еще сталкиваюсь:

import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources

num_trees = {
    0: 3,
    1: 2,
}

# create two RFs, the first one with 3 trees, the second one with 2 trees
# resulting RFs are stored to two files separately

g0 = tf.Graph()
with g0.as_default():
    base_label = 0
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf0 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess = tf.Session()
    sess.run(init_vars)
    X = tf.placeholder(tf.float32, shape=[None, rf0.params.num_features], name="X")
    Y = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op = tf.cast(tf.argmax(rf0.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
    for var in tf.global_variables():
        print("RF0: global variable: {}".format(var.name))
    s = tf.train.Saver()
    s.save(sess, "rf0.tfsess")

g1 = tf.Graph()
with g1.as_default():
    base_label = 1
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf1 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess = tf.Session()
    sess.run(init_vars)
    X = tf.placeholder(tf.float32, shape=[None, rf1.params.num_features], name="X")
    Y = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op = tf.cast(tf.argmax(rf1.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
    for var in tf.global_variables():
        print("RF1: global variable: {}".format(var.name))
    s = tf.train.Saver()
    s.save(sess, "rf1.tfsess")


# re-create/import both RFs into one graph, "subgraph" using variable scope

tf.reset_default_graph()
assert len(tf.global_variables()) == 0

# first RF
base_label = 0
vs = "{}".format(base_label)
with tf.variable_scope(vs):
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf0 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess0 = tf.Session()
    sess0.run(init_vars)
    X0 = tf.placeholder(tf.float32, shape=[None, rf0.params.num_features], name="X")
    Y0 = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op0 = tf.cast(tf.argmax(rf0.inference_graph(X0)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")

# second RF
base_label = 1
vs = "{}".format(base_label)
with tf.variable_scope(vs):
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf1 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess1 = tf.Session()
    sess1.run(init_vars)
    X1 = tf.placeholder(tf.float32, shape=[None, rf1.params.num_features], name="X")
    Y1 = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op1 = tf.cast(tf.argmax(rf1.inference_graph(X1)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")

# check that there are only 5 variables (3 "0/device_dummy_#" and 2 "1/device_dummy_#")
for var in tf.global_variables():
    print("global variable: {}".format(var.name))


# create input map for both graphs and import

# first RF
base_label = 0
vs = "{}".format(base_label)
input_map = {}
for i in range(num_trees[base_label]):
    t = tf.get_default_graph().get_tensor_by_name("{}/device_dummy_{}:0".format(vs, i))
    input_map["device_dummy_{}:0".format(i)] = t
    #print(input_map["device_dummy_{}:0".format(i)])
input_map["X:0"] = X0
input_map["Y:0"] = Y0
input_map["infer_op"] = infer_op0
print(input_map)
# {'device_dummy_0:0': <tf.Tensor '0/device_dummy_0:0' shape=(0,) dtype=float32_ref>, 'device_dummy_1:0': <tf.Tensor '0/device_dummy_1:0' shape=(0,) dtype=float32_ref>, 'device_dummy_2:0': <tf.Tensor '0/device_dummy_2:0' shape=(0,) dtype=float32_ref>, 'X:0': <tf.Tensor '0/X:0' shape=(?, 2) dtype=float32>, 'Y:0': <tf.Tensor '0/Y:0' shape=(?,) dtype=int8>, 'infer_op': <tf.Tensor '0/infer_op:0' shape=(?,) dtype=int8>}
s = tf.train.import_meta_graph("{}.meta".format("rf0.tfsess"), input_map=input_map)
s.restore(sess, "rf0.tfsess")

# second RF
base_label = 1
vs = "{}".format(base_label)
input_map = {}
for i in range(num_trees[base_label]):
    t = tf.get_default_graph().get_tensor_by_name("{}/device_dummy_{}:0".format(vs, i))
    input_map["device_dummy_{}:0".format(i)] = t
    #print(input_map["device_dummy_{}:0".format(i)])
input_map["X:0"] = X1
input_map["Y:0"] = Y1
input_map["infer_op"] = infer_op1
print(input_map)
# {'device_dummy_0:0': <tf.Tensor '1/device_dummy_0:0' shape=(0,) dtype=float32_ref>, 'device_dummy_1:0': <tf.Tensor '1/device_dummy_1:0' shape=(0,) dtype=float32_ref>, 'X:0': <tf.Tensor '1/X:0' shape=(?, 2) dtype=float32>, 'Y:0': <tf.Tensor '1/Y:0' shape=(?,) dtype=int8>, 'infer_op': <tf.Tensor '1/infer_op:0' shape=(?,) dtype=int8>}
s = tf.train.import_meta_graph("{}.meta".format("rf1.tfsess"), input_map=input_map)
s.restore(sess, "rf1.tfsess")

for var in tf.global_variables():
    print("global variable: {}".format(var.name))
# global variable: 0/device_dummy_0:0
# global variable: 0/device_dummy_1:0
# global variable: 0/device_dummy_2:0
# global variable: 1/device_dummy_0:0
# global variable: 1/device_dummy_1:0
# global variable: device_dummy_0:0
# global variable: device_dummy_1:0
# global variable: device_dummy_2:0
# global variable: device_dummy_0:0
# global variable: device_dummy_1:0

Видно, что в конце концов импортированные переменные не отображаются на правильные.Вместо отображения на scope/var_name функция import_meta_graph() по-прежнему создает переменную var_name (без области видимости).В этом примере ошибка не возникает (я не знаю почему), но для моего приложения при попытке импортировать меньший RF после большего, происходит NotFoundError.Это еще одна странная проблема, потому что переменные на самом деле имеют одинаковое имя (дубликат для device_dummy_{0,1}:0) в конце.

...