Tensorflow: получить все тензоры, начиная с «Добавить» - PullRequest
0 голосов
/ 26 апреля 2018

Прямо сейчас я делаю следующее, чтобы получить доступ к тензорам в моем графике

graph = tf.get_default_graph()
add_0 = graph.get_tensor_by_name("Add:0")
add_1 = graph.get_tensor_by_name("Add_1:0")
add_2 = graph.get_tensor_by_name("Add_2:0")

Когда график короткий, такой подход приемлем. Но для более длинных графиков это становится действительно скучным.

Есть ли способ собрать все тензоры, начинающиеся с Add, чистым способом? Что-то вроде:

add = []
for Add in graph.get_tensors_by_name():
    add.append(Add)

(я знаю, что этот псевдокод действительно неправильный)

такой, что я получаю add = [add_0, add_1, add_2, ... ]

Позже я хочу использовать это для этого: sess.run(add, feed_dict={input: data})

1 Ответ

0 голосов
/ 26 апреля 2018

Вы можете получить все тензоры с помощью sess.graph.get_operations(), а затем использовать startswith(), чтобы выбрать те, которые вам нужны. Протестированный код:

import tensorflow as tf

a = tf.constant( [ 1.0 ] )
b = tf.constant( [ 2.0 ] )
c = tf.add( a, b )
d = tf.add( c, b )

with tf.Session() as sess:

    tensors = sum( [ operation.outputs
                             for operation in sess.graph.get_operations() 
                             if operation.name.startswith( "Add") ],
                   [] )
    print( tensors )
    print( sess.run( tensors ) )

Выходы:

tf.Tensor 'Add: 0' shape = (1,) dtype = float32, tf.Tensor 'Add_1: 0' shape = (1,) dtype = float32
[массив ([3.], dtype = float32), массив ([5.], dtype = float32)]

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...