Тензор суммирования по разному количеству строк в столбце - PullRequest
1 голос
/ 24 апреля 2020

Я хочу иметь возможность выполнить сумму уменьшения по тензорному значению в Tensorflow, где для каждого столбца мы суммируем только по подмножеству строк. Для иллюстрации рассмотрим следующее

import tensorflow as tf


X = tf.constant(
    [
        [1, 3, 2],
        [0, 5, 8],
        [1, 6, 2]
    ],
    tf.float32
)

row_max = tf.constant([3, 2, 1], tf.int64)

Затем я хочу сделать следующее в Tensorflow, чтобы градиенты могли течь:

partial_sum = 0.0

for col_idx in range(X.shape[1]):
    partial_sum += tf.reduce_sum(X[:row_max[col_idx], col_idx]])

Это должно дать мне 1+0+1+3+5+2 = 12

Однако я не знаю, как это сделать в Tensorflow. Я рассмотрел несколько различных методов, tf.ragged.range, tf.segment_sum et c. Я думаю, что tf.gather_nd может работать, но даже тогда я не уверен, как построить тензор индекса. В Numpy я мог бы сделать что-то вроде:

import numpy as np


X_np = X.numpy()

idx0 = np.concatenate(
    [
        i * np.ones(row_max[i])
        for i in range X_np.shape[1]
    ],
    axis=0
).astype(np.int64)

idx1 = np.concatenate(
    [
        np.arange(row_max[i])
        for i in range X_np.shape[1]
    ],
    axis=0
).astype(np.int64)


X_np[idx0, idx1].sum()

Каков наилучший способ достижения sh моей цели в Tensorflow?

1 Ответ

1 голос
/ 24 апреля 2020

Вот простой способ сделать это:

import tensorflow as tf

x = tf.constant(
    [
        [1, 3, 2],
        [0, 5, 8],
        [1, 6, 2]
    ],
    tf.float32
)
row_max = tf.constant([3, 2, 1], tf.int64)

# Make mask for each column
row_idx = tf.range(tf.shape(x, out_type=row_max.dtype)[0])
mask = tf.expand_dims(row_idx, 1) < row_max
mask_f = tf.dtypes.cast(mask, x.dtype)
# Mask elements and sum
result = tf.reduce_sum(mask_f * x)
tf.print(result)
# 12

# Alternatively, you can mask the elements and sum
result = tf.reduce_sum(tf.boolean_mask(x, mask))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...