Обведите тензор и примените функцию к каждому элементу - PullRequest
3 голосов
/ 28 июня 2019

Я хочу перебрать тензор, который содержит список Int, и применить функцию к каждому из элементов. В функции каждый элемент получит значение из числа python. Я пробовал простой способ с tf.map_fn, который будет работать с функцией add, такой как следующий код:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

Но следующий код выдает исключение KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

Моя версия тензорного потока 1.13.1. Спасибо вперед.

Ответы [ 2 ]

0 голосов
/ 28 июня 2019

Есть простой способ достичь того, что вы пытаетесь.

Проблема в том, что функция, переданная в map_fn, должна иметь тензоры в качестве параметров и тензор в качестве возвращаемого значения. Однако ваша функция trans_2 принимает простой питон int в качестве параметра и возвращает другой питон int. Вот почему ваш код не работает.

Тем не менее, TensorFlow предоставляет простой способ обернуть обычные функции Python, а именно tf.py_func, вы можете использовать его в вашем случае следующим образом:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

вы можете видеть, что я добавил функцию-обертку, которая ожидает параметр тензора и возвращает тензор, поэтому его можно использовать в map_fn. Приведение используется потому, что по умолчанию в Python используются 64-разрядные целые числа, а в TensorFlow используются 32-разрядные целые числа.

0 голосов
/ 28 июня 2019

Вы не можете использовать такую ​​функцию, потому что параметр x является тензором TensorFlow, а не значением Python. Таким образом, для того, чтобы это работало, вам также необходимо превратить ваш словарь в тензор, но это не так просто, потому что ключи в словаре могут быть не последовательными.

Вместо этого вы можете решить эту проблему без сопоставления, но вместо этого сделать что-то похожее на , что предлагается здесь для NumPy. В TensorFlow вы можете реализовать это так:

import tensorflow as tf

def replace_by_dict(x, d):
    # Get keys and values from dictionary
    keys, values = zip(*d.items())
    keys = tf.constant(keys, x.dtype)
    values = tf.constant(values, x.dtype)
    # Make a sequence for the range of values in the input
    v_min = tf.reduce_min(x)
    v_max = tf.reduce_max(x)
    r = tf.range(v_min, v_max + 1)
    r_shape = tf.shape(r)
    # Mask replacements that are out of the input range
    mask = (keys >= v_min) & (keys <= v_max)
    keys = tf.boolean_mask(keys, mask)
    values = tf.boolean_mask(values, mask)
    # Replace values in the sequence with the corresponding replacements
    scatter_idx = tf.expand_dims(keys, 1) - v_min
    replace_mask = tf.scatter_nd(
        scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
    replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
    replacer = tf.where(replace_mask, replace_values, r)
    # Gather the replacement value or the same value if it was not modified
    return tf.gather(replacer, x - v_min)

# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.constant([1, 2, 3])
    print(sess.run(replace_by_dict(a, kv_dict)))
    # [11, 12, 13]

Это позволит вам иметь значения во входном тензоре без замен (оставлены как есть), а также не требует наличия всех замещающих значений в тензоре. Это должно быть эффективно, если минимальное и максимальное значения в вашем входе не очень далеко.

...