Как сохранить все ряды сгруппированных панд DataFrage, отвечающих определенным критериям? - PullRequest
1 голос
/ 02 июля 2019

Для pandas DataFrame с группами я хочу сохранить все строки до первого появления определенного значения (и отбросить все остальные строки).

MWE:

import pandas as pd
df = pd.DataFrame({'A' : ['foo', 'foo', 'foo', 'bar', 'bar', 'bar', 'tmp'],
                   'B' : [0, 1, 0, 0, 0, 1, 0],
                   'C' : [2.0, 5., 8., 1., 2., 9., 7.]})

дает

    A    B  C
0   foo  0  2.0
1   foo  1  5.0
2   foo  0  8.0
3   bar  0  1.0
4   bar  0  2.0
5   bar  1  9.0
6   tmp  0  7.0

и я хочу сохранить все строки для каждой группы (A - переменная группировки) до B == 1 (включая эту строку). Итак, мой желаемый результат -

    A    B  C
0   foo  0  2.0
1   foo  1  5.0
3   bar  0  1.0
4   bar  0  2.0
5   bar  1  9.0
6   tmp  0  7.0

Как сохранить все строки сгруппированного DataFrage, удовлетворяющие определенным критериям?

Я нашел , как отбрасывать определенные группы, не отвечающие определенным критериям (и сохраняя все остальные строки всех других групп) , но не как отбрасывать определенные строки для всех групп. Самым быстрым, что я получил, было получение индексов строк в каждой группе, которые я хочу сохранить:

df.groupby('A').apply(lambda x: x['B'].cumsum().searchsorted(1))

в результате

A
bar    2
foo    1
tmp    1

Этого недостаточно, поскольку он не возвращает фактические данные (и может быть лучше, если для tmp результат будет 0)

1 Ответ

1 голос
/ 02 июля 2019

Прочитав этот вопрос о разнице между groupby.apply и groupby.aggregate, я понял, что apply работает со всеми столбцами и строками (то есть с DataFrame?) Группы.Так что это моя функция, которая должна применяться к каждой группе:

def f(group):
    index = min(group['B'].cumsum().searchsorted(1), len(group))
    return group.iloc[0:index+1]

Запустив df.groupby('A').apply(f), я получаю желаемый результат:

            A       B   C
A               
bar     3   bar     0   1.0
        4   bar     0   2.0
        5   bar     1   9.0
foo     0   foo     0   2.0
        1   foo     1   5.0
tmp     6   tmp     0   7.0
...