У меня есть 2 матрицы (data
и result
), result
это 2D dask.array
и data
это 3D xarray.DataArray
, я должен сделать калькуляцию следующим образом:
var_idx = 0 # const value
result[i,j] = external_function(data[var_idx,:,i],data[var_idx,:,j])
Я новичок в python и dask, но после нескольких недель обучения приведенный ниже код представляет то, что я мог бы сделать, чтобы сделать это calc ...
def func_block(block,block_info=None):
@jit(nogil=True)
def func_external(a,b):
# just an example
return np.max(np.multiply(a, b))
# result element location of data array
# to compute just lower triangular elements
block_info = block_info[0]
[(s_row,end_row),(s_col,end_col)]= block_info['array-location']
if s_col > s_row:
return block
it = np.nditer(block, flags=['multi_index'],op_flags=['readwrite'])
while not it.finished:
(r_idx,c_idx) = it.multi_index
row = r_idx+s_row
col = c_idx+s_col
if row > col:
it[0] = float(func_external(da_stacked_notnull[0,:,row],da_stacked_notnull[0,:,col]))
it.iternext()
return block
darr_zeros = da.zeros((grid_size,grid_size), chunks=(3000,3000))
darr_zeros = darr_zeros -1
dask_result = darr_zeros.map_blocks(func_block,chunks=(3000,3000),dtype=np.float16)
dask_result = xr.DataArray(dask_result.compute())
Однако у меня есть некоторые проблемы:
1) Если я использую все необходимые данные, я получаю ошибку памяти Python.Я полагаю, что из-за dask_result.compute()
, если я правильно понимаю, .compute()
возвращает пустой массив, но у меня недостаточно памяти для хранения всех результатов в массивном массиве. Как я могу сделать это, используя массив dask?
2) Во всех потоках используется до 50% каждого ядра ... и я думаю, что это связано с GIL, но dask не улучшит его?Можно ли реорганизовать это, чтобы иметь лучшую производительность?
Это матрица данных без всех точек сетки, которые мне действительно нужны:
da_stacked_notnull
Out[1]:
<xarray.DataArray (variable: 1, time: 365, gridcell: 7230)>
array([[[-0.376704, -0.036332, ..., 27.715254, 26.863554],
[-0.465122, -0.152866, ..., 27.227764, 26.556808],
...,
[-0.724707, -0.520708, ..., 29.315022, 29.10007 ],
[-0.835325, -0.704899, ..., 29.425072, 29.086765]]], dtype=float32)
Coordinates:
* time (time) datetime64[ns] 2015-01-01 2015-01-02 ... 2015-12-31
* gridcell (gridcell) MultiIndex
- lon (gridcell) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
- lat (gridcell) float64 -59.5 -58.5 -57.5 -56.5 ... -32.5 -31.5 -30.5
* variable (variable) <U3 'sst'
Заранее спасибо!