Я новичок в Dask и все еще пытаюсь понять, как это сделать гладко. Я экспериментировал с future
API и получил несколько удивительных результатов.
У меня есть простой while
l oop в моем коде, который вызывает функцию cpi5
. Когда я %timeit
выполняю функцию, я получаю:
%timeit min(cpci5(x,N,M,n,lciw,uciw))
6.74 s ± 178 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Теперь я запускаю ту же функцию, но использую распределенный шедулер в dask, и получаю это:
from dask.distributed import Client
client= Client()
%timeit B.append(client.submit(cpci5,x,N,M,n,lciw,uciw)); bb = np.array(client.gather(B))
29 ms ± 6.01 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Пока все хорошо, ожидаемое улучшение есть, но когда я определяю время выполнения самого l oop, где вызывается функция, я почти не вижу различий, и они оба запускаются примерно через 19 с. (У меня есть профиль начального l oop, и более 90% времени вычислений связано с функцией, поэтому улучшение должно быть.)
Результаты последовательны и идентичны.
Что может вызвать такую разницу?
Ниже вы найдете соответствующий фрагмент кода. PS: я уже пошел так далеко, как мог в оптимизации кода, но в моем случае этого недостаточно.
N = 392
n = 326
x = np.arange(0,n+1)
if np.floor(n/2) == n/2:
xvalue = int(n/2 +1)
else :
xvalue = int((n+1)/2)
aa = np.arange(lciw[xvalue-1],np.floor(N/2)).astype(int)
lciw :
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26,
27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41,
42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55,
57, 58, 59, 60, 61, 62, 64, 65, 66, 67, 68, 69, 70,
72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 86,
87, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101,
102, 103, 104, 105, 107, 108, 109, 110, 111, 112, 114, 115, 116,
117, 118, 120, 121, 122, 123, 124, 125, 127, 128, 129, 130, 131,
132, 134, 135, 136, 137, 138, 140, 141, 142, 143, 144, 146, 147,
148, 149, 150, 151, 153, 154, 155, 156, 157, 159, 160, 161, 162,
163, 165, 166, 167, 168, 169, 170, 172, 173, 174, 175, 176, 178,
179, 180, 181, 182, 184, 185, 186, 189, 188, 190, 191, 192, 193,
194, 196, 197, 198, 199, 200, 202, 203, 204, 205, 206, 208, 209,
210, 211, 212, 214, 215, 216, 217, 218, 220, 221, 222, 223, 225,
226, 227, 228, 229, 231, 232, 233, 234, 235, 237, 238, 239, 240,
241, 243, 244, 245, 246, 248, 249, 250, 251, 252, 254, 255, 256,
257, 258, 260, 261, 262, 263, 265, 266, 267, 268, 269, 271, 272,
273, 274, 276, 277, 278, 279, 280, 282, 283, 284, 285, 287, 288,
289, 290, 292, 293, 294, 295, 296, 298, 299, 300, 301, 303, 304,
305, 306, 308, 309, 310, 311, 313, 314, 315, 316, 317, 319, 320,
321, 322, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 336,
338, 339, 340, 341, 343, 344, 345, 346, 348, 349, 350, 351, 353,
354, 355, 357, 358, 359, 360, 362, 363, 364, 366, 367, 368, 369,
371, 372, 373, 375, 376, 377, 379, 380, 382, 383, 384, 386, 387,
389, 390])
uciw :
array([2, 3, 5, 6, 8, 9, 10, 12, 13, 15, 16, 17, 19,
20, 21, 23, 24, 25, 26, 28, 29, 30, 32, 33, 34, 35,
37, 38, 39, 41, 42, 43, 44, 46, 47, 48, 49, 51, 52,
53, 54, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 68,
70, 71, 72, 73, 75, 76, 77, 78, 79, 81, 82, 83, 84,
86, 87, 88, 89, 91, 92, 93, 94, 96, 97, 98, 99, 100,
102, 103, 104, 105, 107, 108, 109, 110, 112, 113, 114, 115, 116,
118, 119, 120, 121, 123, 124, 125, 126, 127, 129, 130, 131, 132,
134, 135, 136, 137, 138, 140, 141, 142, 143, 144, 146, 147, 148,
149, 151, 152, 153, 154, 155, 157, 158, 159, 160, 161, 163, 164,
165, 166, 167, 169, 170, 171, 172, 174, 175, 176, 177, 178, 180,
181, 182, 183, 184, 186, 187, 188, 189, 190, 192, 193, 194, 195,
196, 198, 199, 200, 201, 202, 204, 203, 206, 207, 208, 210, 211,
212, 213, 214, 216, 217, 218, 219, 220, 222, 223, 224, 225, 226,
227, 229, 230, 231, 232, 233, 235, 236, 237, 238, 239, 241, 242,
243, 244, 245, 246, 248, 249, 250, 251, 252, 254, 255, 256, 257,
258, 260, 261, 262, 263, 264, 265, 267, 268, 269, 270, 271, 272,
274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 287, 288,
289, 290, 291, 292, 294, 295, 296, 297, 298, 299, 301, 302, 303,
304, 305, 306, 308, 309, 310, 311, 312, 313, 315, 316, 317, 318,
319, 320, 322, 323, 324, 325, 326, 327, 328, 330, 331, 332, 333,
334, 335, 337, 338, 339, 340, 341, 342, 343, 345, 346, 347, 348,
349, 350, 351, 352, 354, 355, 356, 357, 358, 359, 360, 361, 363,
364, 365, 366, 367, 368, 369, 370, 371, 373, 374, 375, 376, 377,
378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390,
391, 392])
def cpci5(x,N,M,n,lciw,uciw):
f = np.vectorize(hypergeom.pmf)
idd = np.vectorize(ind)
X, m = np.meshgrid(x,M)
kk = idd(m,lciw,uciw) * f(X, N, m, n) #idd just implement a test lciw<= m <=uciw
return min(pd.Series(kk.sum(axis=1)))
M = np.arange(0,N+1) # Initial implementation of the function
ii = 0
while (ii <len(aa)+1):
lciw[xvalue-1] = aa[ii]
uciw[xvalue-1] = N - aa[ii]
bb = min(cpci5(x,N,M,n,lciw,uciw))
if bb >= 1-alpha:
ii1 = ii
ii += 1
else :
ii = len(aa)+1
lciw[xvalue-1] = aa[ii1]
uciw[xvalue-1] = N - lciw[xvalue-1]
M = np.arange(0,N+1) # Distributed version
ii = 0
B = []
while (ii <len(aa)):
lciw[xvalue-1] = aa[ii]
uciw[xvalue-1] = N - aa[ii]
B.append(client.submit(cpci5,x,N,M,n,lciw,uciw))
ii += 1
bb = np.array(client.gather(B))
ii1 = len(bb[bb>1-alpha])-1
lciw[xvalue-1] = aa[ii1]
uciw[xvalue-1] = N - lciw[xvalue-1]