Мы можем решить эту проблему в два простых шага:
Сначала разделим основной тензор по нулю.Итак, данный тензор A
будет выглядеть так [ [1.3], [0], [0.6, 0.7, 0.8] ]
.Это можно сделать с помощью следующей функции:
def split_list(lst, value=0):
"""
Splits a given list based on a given value
default is zero
"""
groups = []
sub_group = []
for i in lst:
if i == 0:
groups.append(sub_group)
sub_group = []
groups.append([0])
else:
sub_group.append(i)
if sub_group:
groups.append(sub_group)
return groups
Во-вторых, суммируйте каждую подгруппу.Возвращенные индексы будут немного хитрыми.Итак, давайте посмотрим на это в коде:
def get_max_indices(groups):
"""
This function takes a list of lists and
returns the indices of the maximum elements
"""
maximum = 0
max_length = 0
total_elements = 0
length_before = 0
for idx, sub_group in enumerate(groups):
summation = sum(sub_group)
if summation > maximum:
maximum = summation
max_length = len(sub_group)
length_before = total_elements
total_elements += len(sub_group)
return [_ for _ in range(length_before, length_before+max_length)]
Теперь давайте попробуем оба:
>>> lst = [1.3, 0, 0.6, 0.7, 0.8]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[2, 3, 4]
Давайте попробуем другой пример:
>>> lst = [1, 2, 3, 0, 6, 9, 0, 10]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[4, 5]
Я надеюсь, что эторешает ваши вопросы.Я знаю, что это немного сложнее, чем вы думаете, но это поможет вам начать.Я думаю, что это можно оптимизировать и очистить, но я оставлю это вам.