Азимутальное суммирование многомерного массива с использованием dask.map_blocks - PullRequest
0 голосов
/ 31 августа 2018

Я пытаюсь распараллелить азимутальное суммирование массива dask, используя map_blocks API. У меня есть следующий код:

import numpy as np
import xarray as xr
import dask.array as dsar
from dask.diagnostics import ProgressBar

N = 2**7
da = xr.DataArray(np.linspace(0.,1.,N*2)[:,np.newaxis,np.newaxis]
                 +(np.arange(-N/2,N/2)**2)[np.newaxis,:,np.newaxis]
                 +(np.arange(-N/2,N/2)**2)[np.newaxis,np.newaxis,:], 
                 dims=['time','x','y'],
                 coords={'time':range(N*2),'x':np.arange(-N/2,N/2),'y':np.arange(-N/2,N/2)}
                 ).chunk({'time':1})

R = np.sqrt(da.x.data[np.newaxis,:]**2 + da.y.data[:,np.newaxis]**2)
nfactor = 4
nbins = int(len(R)/nfactor)
r = np.linspace(0.,da.x.data.max(),nbins)
ridx = np.digitize(np.ravel(R), r)
area = np.bincount(ridx)

axes=[-1,-2]
M = [da.shape[n] for n in axes]
data = da.data.reshape((N*2,np.prod(M)))
with ProgressBar():
     azisum = dsar.map_blocks(np.bincount, xr.DataArray(ridx*np.ones(N*2)[:,np.newaxis], 
                                                  dims=['time','points'], 
                                                  coords={'time':da.time.data,'points':range(np.prod(M))}
                                                  ).chunk({'time':1}).data, 
                             weights=data, chunks=(N*2,nbins), dtype=da.dtype
                             ).compute()

Я думал, что это сработает, поскольку x и weights в numpy.bincount имеют одинаковую форму, но я получаю следующую ошибку:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-afd3fe2d4880> in <module>()
     10                                                       coords={'time':da.time.data,'points':range(np.prod(M))}
     11                                                       ).chunk({'time':1}).data, 
---> 12                             weights=data, chunks=(N*2,nbins), dtype=da.dtype
     13                             ).compute()
     14 azisum

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/base.py in compute(self, **kwargs)
    154         dask.base.compute
    155         """
--> 156         (result,) = compute(self, traverse=False, **kwargs)
    157         return result
    158 

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/base.py in compute(*args, **kwargs)
    393     keys = [x.__dask_keys__() for x in collections]
    394     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 395     results = schedule(dsk, keys, **kwargs)
    396     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    397 

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, **kwargs)
     73     results = get_async(pool.apply_async, len(pool._pool), dsk, result,
     74                         cache=cache, get_id=_thread_get_id,
---> 75                         pack_exception=pack_exception, **kwargs)
     76 
     77     # Cleanup pools associated to dead threads

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    499                         _execute_task(task, data)  # Re-execute locally
    500                     else:
--> 501                         raise_exception(exc, tb)
    502                 res, worker_id = loads(res_info)
    503                 state['cache'][key] = res

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/compatibility.py in reraise(exc, tb)
    110         if exc.__traceback__ is not tb:
    111             raise exc.with_traceback(tb)
--> 112         raise exc
    113 
    114 else:

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    270     try:
    271         task, data = loads(task_info)
--> 272         result = _execute_task(task, data)
    273         id = get_id()
    274         result = dumps((result, id))

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/local.py in _execute_task(arg, cache, dsk)
    251         func, args = arg[0], arg[1:]
    252         args2 = [_execute_task(a, cache) for a in args]
--> 253         return func(*args2)
    254     elif not ishashable(arg):
    255         return arg

/Users/uchidatakaya/anaconda/envs/xrft/lib/python3.6/site-packages/dask/compatibility.py in apply(func, args, kwargs)
     91     def apply(func, args, kwargs=None):
     92         if kwargs:
---> 93             return func(*args, **kwargs)
     94         else:
     95             return func(*args)

ValueError: object too deep for desired array

Может кто-нибудь сказать мне, почему я получаю сообщение об ошибке слишком глубокого объекта? Заранее спасибо за помощь!

...