Проблема с производительностью при groupby.shift - PullRequest
4 голосов
/ 08 мая 2019

Тестовый код:

SIZE_MULT = 5
data = np.random.randint(0, 255, size=10**SIZE_MULT, dtype='uint8')
index = pd.MultiIndex.from_product(
            [list(range(10**(SIZE_MULT-1))), list('ABCDEFGHIJ')],
            names = ['d', 'l'])        
test = pd.DataFrame(data, index, columns = ['data'])
test.head()
test['data'].dtype

выход

        data
d   l   
0   A   137
    B   156
    C   48
    D   186
    E   170

dtype('uint8')

И предположим, что мы хотим сгруппировать по 0 уровням индекса и сместить каждую группу (например, шаг смещения = 2).

%%time
shifted = test.groupby(axis=0, level=[0]).shift(2)
print(shifted['data'].dtype)

Выход:

float64
CPU times: user 9.43 ms, sys: 56 µs, total: 9.49 ms
Wall time: 8.29 ms

Теперь к проблеме: если мы хотим сохранить наш dtype 'uint8', мы должны избавиться от None s и установить наше значение заполнения, например, на 0. Но сейчас мы получим ОГРОМНОЕ время выполнения кода:

%%time
shifted = test.groupby(axis=0, level=[0]).shift(2, fill_value = 0)
shifted.head()
print(shifted['data'].dtype)

Выход:

uint8
CPU times: user 5.9 s, sys: 38.4 ms, total: 5.94 s
Wall time: 5.89 s

Итак, вопрос в том, почему это так долго? Если мы возьмем 1-й сдвинутый кадр данных без fill_value и добавим несколько строк кода для достижения того же результата:

%%time
shifted = test.groupby(axis=0, level=[0]).shift(2)
shifted.fillna(0, inplace=True)
shifted = shifted.astype(np.uint8)
print(shifted['data'].dtype)

Выход:

uint8
CPU times: user 9.64 ms, sys: 3.68 ms, total: 13.3 ms
Wall time: 11.3 ms

Это добавит всего несколько мс, а не 5 секунд.

РЕДАКТИРОВАТЬ: соответствующий GitHub выпуск

1 Ответ

1 голос
/ 02 июня 2019

Исходный код состоит в том, что с указанным значением заполнения используется медленный вызов применения.Без значения заполнения он может использовать гораздо более быстрый цитонизированный результат:

Код по ссылке:

def shift(self, periods=1, freq=None, axis=0, fill_value=None):
    #...
    if freq is not None or axis != 0 or not isna(fill_value):
        return self.apply(lambda x: x.shift(periods, freq,
                                            axis, fill_value))

    return self._get_cythonized_result('group_shift_indexer',
                                       self.grouper, cython_dtype=np.int64,
                                       needs_ngroups=True,
                                       result_is_index=True,
                                       periods=periods)

Так что в этом случае я бы использовал .fillna() после.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...