Суммирование смежных ненулевых тензорных значений - PullRequest
0 голосов
/ 20 декабря 2018

Я пытаюсь найти суммирование смежных ненулевых значений тензора, как показано ниже

Скажем, у меня есть тензор A = [1.3, 0.0, 0.6, 0.7, 0.8].И я хочу

1) суммировать непрерывные ненулевые значения тензора для вывода [1.3, 0.0, 2.1], а затем выбрать максимальное значение, равное 2.1.

2) также найти индексы, которые были использованы для суммирования этих значений.В этом случае это будет 2, 3, 4.

Ответы [ 2 ]

0 голосов
/ 21 декабря 2018

Мой подход несколько отличается от @Anwarvic.Я пытаюсь сделать это за один проход.Смотрите функцию ниже.Он перемещается по массиву и ведет журнал максимального значения, которое он видел до сих пор, и текущей суммы.Текущая сумма обновляется до 0, если мы достигаем нуля или суммируем значение до текущего, если оно не равно нулю.

def find_continguous_max_sum(t): 
    max_,cur=0,0 
    max_indices, cur_indices = [], []

    for idx, i in enumerate(t): 
        if (i==0.0): 
            cur = 0 
            cur_indices = []
        else: 
            cur = cur+i
            cur_indices.append(idx)
        if (max_ < cur): 
            max_ = cur 
            max_indices = cur_indices 

    return max_, max_indices

#usage
find_contiguous_max_sum(A)
0 голосов
/ 21 декабря 2018

Мы можем решить эту проблему в два простых шага:

Сначала разделим основной тензор по нулю.Итак, данный тензор A будет выглядеть так [ [1.3], [0], [0.6, 0.7, 0.8] ].Это можно сделать с помощью следующей функции:

def split_list(lst, value=0):
    """
    Splits a given list based on a given value
    default is zero
    """
    groups = []
    sub_group = []
    for i in lst:
        if i == 0:
            groups.append(sub_group)
            sub_group = []
            groups.append([0])
        else:
            sub_group.append(i)
    if sub_group:
        groups.append(sub_group)
    return groups

Во-вторых, суммируйте каждую подгруппу.Возвращенные индексы будут немного хитрыми.Итак, давайте посмотрим на это в коде:

def get_max_indices(groups):
    """
    This function takes a list of lists and 
    returns the indices of the maximum elements
    """
    maximum = 0
    max_length = 0
    total_elements = 0
    length_before = 0
    for idx, sub_group in enumerate(groups):
        summation = sum(sub_group)
        if summation > maximum:
            maximum = summation
            max_length = len(sub_group)
            length_before = total_elements
        total_elements += len(sub_group)
    return [_ for _ in range(length_before, length_before+max_length)]

Теперь давайте попробуем оба:

>>> lst = [1.3, 0, 0.6, 0.7, 0.8]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[2, 3, 4]

Давайте попробуем другой пример:

>>> lst = [1, 2, 3, 0, 6, 9, 0, 10]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[4, 5]

Я надеюсь, что эторешает ваши вопросы.Я знаю, что это немного сложнее, чем вы думаете, но это поможет вам начать.Я думаю, что это можно оптимизировать и очистить, но я оставлю это вам.

...