TensorFlow py_function - вложенный тип вывода? - PullRequest
3 голосов
/ 19 февраля 2020

Можно ли указать вложенный тип вывода для py_function?

для TensorFlows. Как конкретный случай c, я бы хотел, чтобы py_function имел возврат тип ((tf.float32, tf.float32), (tf.float32, tf.float32)), где отдельные элементы не обязательно имеют одинаковые размеры. Есть ли способ указать это для py_function?

Так же, как некоторое понимание того, почему это полезно в моем случае, у меня есть tf.data.Dataset со списками путей к файлам. py_function берет один из этих путей к файлу и из файла генерирует отрицательный и положительный пример вместе с соответствующими метками, в результате чего получается ((positive_data, positive_label), (negative_data, negative_label)) (обратите внимание, что метки не обязательно являются единичными значениями, но они также не одинаковой формы в качестве входных данных). Этот py_function может быть сопоставлен с набором данных и (с описанной выше структурой) имеет один уровень, сглаженный для создания обучающего набора данных со (data, label) структурированными элементами. Несмотря на то, что возможно иметь обходной путь, при котором данные и метка стэкируются в py_function и позже не стэкируются (или начинаются полностью неструктурированными из функции py_function и только в паре впоследствии), это приводит к грязной и запутанной настройке. Если py_function может напрямую выводить тип ((tf.float32, tf.float32), (tf.float32, tf.float32)), это приведет к более чистой настройке.

1 Ответ

2 голосов
/ 24 февраля 2020

Тип вывода tf.py_function не может быть вложенной последовательностью. Однако при использовании tf.py_function с API tf.data необходимо создать функцию-обертку (tf_foo в приведенном ниже примере), и вы можете вкладывать выходные данные в эту функцию.

import tensorflow as tf

# The python function.
def foo(x):
    return x, x, x, x

# Wrap the python function to make it compatible with `tf.data.Dataset.map`.
def tf_foo(x):
    a, b, c, d = tf.py_function(foo, [x], Tout=[tf.float32, tf.float32, tf.float32, tf.float32])
    return (a, b), (c, d)

dset = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
dset.map(tf_foo)
# <MapDataset shapes: ((<unknown>, <unknown>), (<unknown>, <unknown>)),
#  types: ((tf.float32, tf.float32), (tf.float32, tf.float32))>

Это также продемонстрировано в руководстве TensorFlow .

...