Как восстановить висящий tf.py_func в tf.data.Dataset () с помощью API tf.saved_model? - PullRequest
0 голосов
/ 16 апреля 2019

После тщательного исследования восстановления tf.py_func() при использовании API save_model, я не смог найти другую информацию, кроме документированной в тензор потока :

Операция должна выполняться в том же адресном пространстве, что и программа Python, которая вызывает tf.py_func(). Если вы используете распределенный TensorFlow, вы должны запустить tf.train.Server в том же процессе, что и программа, которая вызывает tf.py_func(), и вы должны прикрепить созданную операцию к устройству на этом сервере (например, с помощью tf.device():)

Два фрагмента сохранения / загрузки помогают проиллюстрировать ситуацию.

Сохранить часть:

def wrapper(x, y):
    with tf.name_scope('wrapper'):
        return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])

def Copy(x, y):
    return x, y

x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')

with tf.name_scope('input'):
    ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
    ds = ds.map(wrapper)
    ds = ds.batch(1)
    it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
    it_init_op = it.make_initializer(ds, name='it_init_op')

x_it, y_it = it.get_next()

# Simple operation
with tf.name_scope('add'):
    res = tf.add(x_it, y_it)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
    sess.run([res])
    tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})

Загрузка части:

graph = tf.Graph()
graph.as_default()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test')

    res = graph.get_tensor_by_name('add/Add:0')
    it_init_op = graph.get_operation_by_name('input/it_init_op')
    x_ph = graph.get_tensor_by_name('x_ph:0')
    y_ph = graph.get_tensor_by_name('y_ph:0')
    sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]})

    for _ in range(5):
        sess.run([res])

Ошибка:

ValueError: обратный вызов pyfunc_0 не найден

Хорошо известно, что функция, заключенная в tf.py_func(), не сохраняется вместе с моделью. У кого-нибудь есть решение, чтобы восстановить это, используя небольшую подсказку, заданную tf doc, применяя tf.train.Server

...