Вы не можете использовать такую функцию, потому что параметр x
является тензором TensorFlow, а не значением Python. Таким образом, для того, чтобы это работало, вам также необходимо превратить ваш словарь в тензор, но это не так просто, потому что ключи в словаре могут быть не последовательными.
Вместо этого вы можете решить эту проблему без сопоставления, но вместо этого сделать что-то похожее на , что предлагается здесь для NumPy. В TensorFlow вы можете реализовать это так:
import tensorflow as tf
def replace_by_dict(x, d):
# Get keys and values from dictionary
keys, values = zip(*d.items())
keys = tf.constant(keys, x.dtype)
values = tf.constant(values, x.dtype)
# Make a sequence for the range of values in the input
v_min = tf.reduce_min(x)
v_max = tf.reduce_max(x)
r = tf.range(v_min, v_max + 1)
r_shape = tf.shape(r)
# Mask replacements that are out of the input range
mask = (keys >= v_min) & (keys <= v_max)
keys = tf.boolean_mask(keys, mask)
values = tf.boolean_mask(values, mask)
# Replace values in the sequence with the corresponding replacements
scatter_idx = tf.expand_dims(keys, 1) - v_min
replace_mask = tf.scatter_nd(
scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
replacer = tf.where(replace_mask, replace_values, r)
# Gather the replacement value or the same value if it was not modified
return tf.gather(replacer, x - v_min)
# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
a = tf.constant([1, 2, 3])
print(sess.run(replace_by_dict(a, kv_dict)))
# [11, 12, 13]
Это позволит вам иметь значения во входном тензоре без замен (оставлены как есть), а также не требует наличия всех замещающих значений в тензоре. Это должно быть эффективно, если минимальное и максимальное значения в вашем входе не очень далеко.