Вам необходимо определить пользовательскую функцию для подачи в tf.map_fn()
- Tensorflow dox
Функции Mapper отображают (как ни странно) существующий объект (тензор) в новый, используя определяемую вами функцию.
Они применяют пользовательскую функцию к каждому элементу в объекте, без всякого шумихи вокруг циклов for
.
Например (не проверенный код, может не работать - на моем телефоне атм):
def custom(a):
b = a + 1
return b
original = np.array([2,2,2])
mapped = tf.map_fn(custom, original)
# mapped == [3, 3, 3] ... hopefully
Во всех примерах Tensorflow используются функции lambda
, поэтому вам может потребоваться определить ваши функции таким образом, если вышеперечисленное не работает. Пример тензорного потока:
elems = np.array([1, 2, 3, 4, 5, 6])
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]
Редактировать:
Кроме того, функции карты намного проще распараллелить, чем для циклов - предполагается, что каждый элемент объекта обрабатывается уникально - так что вы можете увидеть повышение производительности, используя их.
Редактировать 2:
Что касается части "уменьшить сумму, но не по этому индексу", я настоятельно рекомендую вам начать оглядываться на матричные операции ... Как уже упоминалось, map
функции работают поэлементно - они не знают о других элементах , Функция reduce
- это то, что вам нужно, но даже если вы пытаетесь и суммируете «не этот индекс», они даже финишируют, а тензорный поток строится на основе матричных операций ... Не парадигмы MapReduce.
Что-то в этом роде может помочь:
sess = tf.Session()
var = np.ones([3, 3, 3]) * 5
zero_identity = tf.linalg.set_diag(
var, tf.zeros(var.shape[0:-1], dtype=tf.float64)
)
exp_one = tf.exp(var)
exp_two = tf.exp(zero_identity)
summed = tf.reduce_sum(exp_two, axis = [0,1])
final = exp_one / summed
print("input matrix: \n", var, "\n")
print("Identities of the matrix to Zero: \n", zero_identity.eval(session=sess), "\n")
print("Exponential Values numerator: \n", exp_one.eval(session=sess), "\n")
print("Exponential Values to Sum: \n", exp_two.eval(session=sess), "\n")
print("Summed values for zero identity matrix\n ... along axis [0,1]: \n", summed.eval(session=sess), "\n")
print("Output:\n", final.eval(session=sess), "\n")