Можно ли улучшить производительность использования xarray groupby / map? - PullRequest
0 голосов
/ 06 февраля 2020

Я новичок в использовании xarray / dask и пытаюсь нормализовать большое количество данных сонара, используя xarray. Я испробовал несколько подходов и, похоже, не смог заставить его работать даже на скорости, которую я ожидал бы возможной.

Есть как минимум две проблемы с производительностью, которые я видел: 1) Кажется, это не так распараллелить, используя dask по всем ядрам на моей машине для groupby / map 2) Для однопоточной операции это также очень медленно по сравнению с моими ожиданиями (и без использования полного CPU)

Расчет:

Операцию, которую я пытаюсь выполнить, я описал как нормализацию. Я имею в виду, что у меня есть данные сонара, где каждый временной интервал имеет 2000 возвращаемых отсчетов глубины, но диапазон глубин различен для каждого временного интервала. Я нормализую каждый временной интервал до согласованного диапазона глубины.

Так, например: * временной интервал t = 0, у меня есть 2000 выборок в диапазоне от глубины 0 м до 20 м, так что каждая ячейка выборки является возвращением эха для 10 см воды * временной интервал t = 1, у меня 2000 образцов в диапазоне глубин от 0 до 40 м, поэтому каждая ячейка для образцов является отражением эха для 20 см воды

Я хочу преобразовать их все на одинаковую глубину спектр. Это в основном одноуровневая операция сокращения + сдвиг (если не от 0 м). На данный момент, поскольку инструменты, по-видимому, не существуют, я использую интерполяцию вместо приведения вниз, что на данный момент достаточно.

У меня есть отдельный пример ниже (с использованием rand вместо реальных данных), где я делаю этот расчет ,

В вычислениях я использую groupby, чтобы дать один временной интервал, а затем карту, чтобы применить интерполяцию к этому временному интервалу.

Теоретически каждый временной интервал может быть рассчитан параллельно, поскольку они независимый.

Я сделал следующие наблюдения с кодом ниже: * Использование чанка замедляет его примерно с 5 мсек c за интервал времени до 70 мсек c за интервал времени * Появляется большая часть стоимости быть вызовами: ds..values.item ()

На моем оборудовании это занимает около 5 мсек c за интервал времени, и у меня есть миллионы интервалов времени, что неожиданно медленно для типа операции, которой я являюсь делает на 2000 образцов.

Есть ли лучший способ сделать это?

import sys
import math
import logging
import dask
import xarray
import numpy

logger = logging.getLogger('main')

if __name__ == '__main__':
    logging.basicConfig(
        stream=sys.stdout,
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')

    logger.info('Starting dask client')
    client = dask.distributed.Client()

    SIZE = 100000
    SONAR_BINS = 2000
    time = range(0, SIZE)
    upper_limit = numpy.random.randint(0, 10, (SIZE))
    lower_limit = numpy.random.randint(20, 30, (SIZE))
    sonar_data = numpy.random.randint(0, 255, (SIZE, SONAR_BINS))

    channel = xarray.Dataset({
            'upper_limit': (['time'], upper_limit, {'units': 'depth meters'}),
            'lower_limit': (['time'],  lower_limit, {'units': 'depth meters'}),
            'data': (['time', 'depth_bin'], sonar_data, {'units': 'amplitude'}),
        },
        coords={
            'depth_bin': (['depth_bin'], range(0,SONAR_BINS)),
            'time': (['time'], time)
        })

    logger.info('get overall min/max radar range we want to normalize to called the adjusted range')
    adjusted_min, adjusted_max = channel.upper_limit.min().values.item(), channel.lower_limit.max().values.item()
    adjusted_min = math.floor(adjusted_min)
    adjusted_max = math.ceil(adjusted_max)
    logger.info('adjusted_min: %s, adjusted_max: %s', adjusted_min, adjusted_max)

    bin_count = len(channel.depth_bin)
    logger.info('bin_count: %s', bin_count)

    adjusted_depth_per_bin = (adjusted_max - adjusted_min) / bin_count
    logger.info('adjusted_depth_per_bin: %s', adjusted_depth_per_bin)

    adjusted_bin_depths = [adjusted_min + (j * adjusted_depth_per_bin) for j in range(0, bin_count)]
    logger.info('adjusted_bin_depths[0]: %s ... [-1]: %s', adjusted_bin_depths[0], adjusted_bin_depths[-1])

    def Interp(ds):
        # Ideally instead of using interp we will use some kind of downsampling and shift
        # this doesnt exist in xarray though and interp is good enough for the moment

        # I just added this to debug
        t = ds.time.values.item()
        if (t % 100) == 0:
            total = len(channel.time)
            perc = 100.0 * t / total
            logger.info('%s : %s of %s', perc, t, total)

        unadjusted_depth_amplitudes = ds.data
        unadjusted_min = ds.upper_limit.values.item()
        unadjusted_max = ds.lower_limit.values.item()
        unadjusted_depth_per_bin = (unadjusted_max - unadjusted_min) / bin_count

        index_mapping = [((adjusted_min + (bin * adjusted_depth_per_bin)) - unadjusted_min) / unadjusted_depth_per_bin for bin in range(0, bin_count)]
        adjusted_depth_amplitudes = unadjusted_depth_amplitudes.interp(coords={'depth_bin':index_mapping}, method='linear', assume_sorted=True)
        adjusted_depth_amplitudes = adjusted_depth_amplitudes.rename({'depth_bin':'depth'}).assign_coords({'depth':adjusted_bin_depths})

        #logger.info('%s, \n\tunadjusted_depth_amplitudes.values:%s\n\tunadjusted_min:%s\n\tunadjusted_max:%s\n\tunadjusted_depth_per_bin:%s\n\tindex_mapping:%s\n\tadjusted_depth_amplitudes:%s\n\tadjusted_depth_amplitudes.values:%s\n\n', ds, unadjusted_depth_amplitudes.values, unadjusted_min, unadjusted_max, unadjusted_depth_per_bin, index_mapping, adjusted_depth_amplitudes, adjusted_depth_amplitudes.values)
        return adjusted_depth_amplitudes

    # Lets split into chunks so could be performed in parallel
    # This doesnt work to parallelize and only slows it down a lot
    #logger.info('chunk')
    #channel = channel.chunk({'time':100})

    logger.info('groupby')
    g = channel.groupby('time')

    logger.info('do interp')
    normalized_depth_data = g.map(Interp)

    logger.info('done')
...