Вот один из способов использования groupby
, но сначала вам нужно определить функцию, выполняющую то, что вы ищете в каждой группе. Чтобы объяснить идею, давайте рассмотрим простой фрейм данных dfs = pd.DataFrame({'a':[1,2,3,4,14,20,30,31]})
Я давно пытался решить эту проблему, пытаясь избежать зацикливания, и это кажется сложным. Вот идея, которую я заканчиваю. В numpy вы можете использовать substract
в сочетании с outer
, чтобы получить все различия между каждым элементом один к одному
print (np.subtract.outer(dfs.a, dfs.a))
array([[ 0, -1, -2, -3, -13, -19, -29, -30],
[ 1, 0, -1, -2, -12, -18, -28, -29],
[ 2, 1, 0, -1, -11, -17, -27, -28],
[ 3, 2, 1, 0, -10, -16, -26, -27],
[ 13, 12, 11, 10, 0, -6, -16, -17],
[ 19, 18, 17, 16, 6, 0, -10, -11],
[ 29, 28, 27, 26, 16, 10, 0, -1],
[ 30, 29, 28, 27, 17, 11, 1, 0]], dtype=int64)
Теперь, например, в column 0
вы можете видеть, что разница >10
начинается с row 4
, затем идет к column 4
, разница >10
начинается с row 6
и продолжается до column 6
вас. не получить разницу достаточно большой. Таким образом, фильтрация будет хранить строки 0, 4 и 6, что соответствует значениям [1,14,30]. Чтобы получить эти числа, вы можете сравнить np.substract.outer
с 10 и sum
с axis=0
, такими как:
arr = (np.subtract.outer(dfs.a, dfs.a) <=10).sum(0)
print (arr)
array([4, 4, 4, 5, 6, 7, 8, 8])
Теперь вы видите, arr[0] = 4
, затем arr[4] = 6
, затем arr[6]=8
в этом примере выходит за пределы, поэтому остановитесь. Один из способов поймать это число - использовать while
(если у кого-то есть решение numpy
, мне это интересно)
list_ind = [0] # initialize list of index to keep with 0
arr = (np.subtract.outer(dfs.a, dfs.a) <=10).sum(0)
i = arr[0]
while i < len(arr):
list_ind.append(i)
i = arr[i]
print (list_ind)
[0, 4, 6]
print (dfs.iloc[list_ind])
a
0 1
4 14
6 30
Теперь со всей проблемой и groupby
вы можете сделать:
# it seems you need to convert the column frame_no to integer
df['frame_int'] = pd.to_numeric(df['frame_no'])
df = df.sort_values('frame_int') #ensure data to be sorted by frame_int, whatever the global_id
#define the function looking for the ind
def find_ind (df_g):
list_ind = [0]
arr = (np.subtract.outer(df_g.frame_int, df_g.frame_int) <= 10).sum(0)
i = arr[0]
while i <len(arr):
list_ind.append(i)
i = arr[i]
return df_g.iloc[list_ind]
#create the filtered dataframe
df_filtered = (df.groupby('global_id').apply(find_ind)
.drop('frame_int',axis=1).reset_index(drop=True))
print (df_filtered)
seq_name label pedestrian_id frame_no global_id
0 1 crossing 1 1 1
1 1 crossing 2 1 2
2 1 crossing 2 12 2
3 1 crossing 2 29 2
4 2 crossing 1 34 3
5 2 crossing 1 49 3
Если вы хотите сохранить индекс исходных строк, вместо него вы можете добавить level=0
в reset_index
, например reset_index(level=0,drop=True)
.