Я пытаюсь реструктурировать свой код, чтобы использовать Dask вместо NumPy для вычислений с большим массивом.Однако я борюсь с производительностью Dask во время выполнения:
In[15]: import numpy as np
In[16]: import dask.array as da
In[17]: np_arr = np.random.rand(10, 10000, 10000)
In[18]: da_arr = da.from_array(np_arr, chunks=(-1, 'auto', 'auto'))
In[19]: %timeit np.mean(np_arr, axis=0)
1 loop, best of 3: 2.59 s per loop
In[20]: %timeit da_arr.mean(axis=0).compute()
1 loop, best of 3: 4.23 s per loop
Я смотрел на похожие вопросы ( почему точечный продукт в dask медленнее, чем в numpy ), но играю вокругс размером куска не помогло.В основном я буду использовать массивы примерно того же размера, что и выше.Рекомендуется ли использовать NumPy вместо Dask для таких массивов или я могу что-то настроить?Я также попытался использовать Client
из dask.distributed
и запустил его с 16 процессами и 4 потоками на процесс (16-ядерный процессор), но это сделало его еще хуже.Заранее спасибо!
РЕДАКТИРОВАТЬ: Я немного поиграл с Dask и распределенной обработки.Передача данных (сброс массива и получение результата), по-видимому, является основным ограничением / проблемой, тогда как вычисления действительно быстрые (436 мс по сравнению с 9,51 с).Но даже для client.compute()
время стены больше (12,1 с), чем для do_stuff(data)
.Можно ли как-то улучшить это и передачу данных вообще?
In[3]: import numpy as np
In[4]: from dask.distributed import Client, wait
In[5]: from dask import delayed
In[6]: import dask.array as da
In[7]: client = Client('address:port')
In[8]: client
Out[8]: <Client: scheduler='tcp://address:port' processes=4 cores=16>
In[9]: data = np.random.rand(400, 100, 10000)
In[10]: %time [future] = client.scatter([data])
CPU times: user 8.36 s, sys: 5.08 s, total: 13.4 s
Wall time: 24.5 s
In[11]: x = da.from_delayed(delayed(future), shape=data.shape, dtype=data.dtype)
In[12]: x = x.rechunk(chunks=('auto', 'auto', 'auto'))
In[13]: x = client.persist(x)
In[14]: {w: len(keys) for w, keys in client.has_what().items()}
Out[14]:
{'tcp://address:port': 65,
'tcp://address:port': 0,
'tcp://address:port': 0,
'tcp://address:port': 0}
In[15]: client.rebalance(x)
In[16]: {w: len(keys) for w, keys in client.has_what().items()}
Out[16]:
{'tcp://address:port': 17,
'tcp://address:port': 16,
'tcp://address:port': 16,
'tcp://address:port': 16}
In[17]: def do_stuff(arr):
... arr = arr/3. + arr**2 - arr**(1/2)
... arr[arr >= 0.5] = 1
... return arr
...
In[18]: %time future_compute = client.compute(do_stuff(x)); wait(future_compute)
Matplotlib support failed
CPU times: user 387 ms, sys: 49.5 ms, total: 436 ms
Wall time: 12.1 s
In[19]: future_compute
Out[19]: <Future: status: finished, type: ndarray, key: finalize-54eb04bbe03eee8af686fd43b41eb161>
In[21]: %timeit future_compute.result()
1 loop, best of 3: 19.4 s per loop
In[21]: %time do_stuff(data)
CPU times: user 4.49 s, sys: 5.02 s, total: 9.51 s
Wall time: 9.5 s