Я решил эту проблему, написав функцию, которую я назвал groupby_many
. Работает на объектах Series
и DataFrame
:
import numpy as np
import pandas as pd
def groupby_many(data, groups):
"""
Groups a Series or DataFrame object where each row can belong to many groups.
Parameters
----------
data : Series or DataFrame
The data to group
groups : iterable of iterables
For each row in data, the groups that row belongs to.
A row can belong to zero, one, or multiple groups.
Returns
-------
A GroupBy object
"""
pairs = [(i, g) for (i, gg) in enumerate(groups) for g in gg]
row, group = zip(*pairs)
return data.iloc[list(row)].groupby(list(group))
Работает путем создания версии данных, в которой каждая строка дублируется n раз, где n - количество групп, к которым принадлежит строка. Каждая строка в этой версии принадлежит только одной группе, поэтому теперь она может обрабатываться обычным groupby
.
Чтобы увидеть это в действии на примере данных в вопросе:
>>> df = pd.DataFrame.from_dict({
'Item': ["Apple", "Tomato", "Potato", "Orange"],
'Count': [5, 10, 3, 20],
'Tags': [['fruit', 'red'], ['vegetable', 'red'], [], ['fruit']]})
>>> df = df.set_index('Item')
>>> print(df)
Count Tags
Item
Apple 5 [fruit, red]
Tomato 10 [vegetable, red]
Potato 3 []
Orange 20 [fruit]
>>> result = groupby_many(df, df['Tags']).sum()
>>> print(result)
Count
fruit 25
red 15
vegetable 10