interpolate.griddata использует только одно ядро - PullRequest
0 голосов
/ 07 сентября 2018

Я интерполирую двумерный массив Numpy, чтобы заполнить пропущенные значения, отмеченные NaN. Следующий код работает, но использует только одно ядро. Есть ли лучшие функции, которые я могу использовать, чтобы использовать все 24 ядра, которые у меня есть?

x = np.arange(0, array.shape[1])
y = np.arange(0, array.shape[0])
#mask invalid values
array = np.ma.masked_invalid(array)
xx, yy = np.meshgrid(x, y)
#get only the valid values
x1 = xx[~array.mask]
y1 = yy[~array.mask]
newarr = array[~array.mask]

GD1 = interpolate.griddata((x1, y1), newarr.ravel(),
                      (xx, yy),
                         method='cubic')

1 Ответ

0 голосов
/ 08 сентября 2018

Я думаю, что вы можете сделать это с dask . Я не слишком знаком с Dask, но вот начало:

import numpy as np
from scipy import interpolate
import dask.array as da
import matplotlib.pyplot as plt
from dask import delayed

# create data with random missing entries
ar_size = 2000
chunk_size = 500
z_array = np.ones((ar_size, ar_size))
z_array[np.random.randint(0, ar_size-1, 50),
      np.random.randint(0, ar_size-1, 50)]= np.nan

# XY coords
x = np.linspace(0, 3, z_array.shape[1])
y = np.linspace(0, 3, z_array.shape[0])

# gen sin wave for testing
z_array = z_array * np.sin(x)
# prove there are nans in the dataset
assert np.isnan(np.sum(z_array))

xx, yy = np.meshgrid(x, y)
print("global x.size: ", xx.size)

# make dask arrays
dask_xyz = da.from_array((xx, yy, z_array), chunks=(3, chunk_size, "auto"), name="dask_all")
dask_xx = dask_xyz[0,:,:]
dask_yy = dask_xyz[1,:,:]
dask_zz = dask_xyz[2,:,:]

# select only valid values
dask_valid_y1 = dask_yy[~da.isnan(dask_zz)]
dask_valid_x1 = dask_xx[~da.isnan(dask_zz)]
dask_newarr = dask_zz[~da.isnan(dask_zz)]

def gd_wrapped(x1, y1, newarr, xx, yy):
    # note: linear and cubic griddata impl do not extrapolate
    # and therefore fail near the boundaries... see RBF interp instead
    print("local x.size: ", x1.size)
    gd_zz = interpolate.griddata((x1, y1), newarr.ravel(),
                               (xx, yy),
                               method='nearest')
    return gd_zz

def rbf_wrapped(x1, y1, newarr, xx, yy):
    rbf_interpolant = interpolate.Rbf(x1, y1, newarr, function='linear')
    return rbf_interpolant(xx, yy)

# interpolate
# gd_chunked = [delayed(rbf_wrapped)(x1, y1, newarr, xx, yy) for \
gd_chunked = [delayed(gd_wrapped)(x1, y1, newarr, xx, yy) for \
            x1, y1, newarr, xx, yy \
            in \
            zip(dask_valid_x1.to_delayed().flatten(),
                dask_valid_y1.to_delayed().flatten(),
                dask_newarr.to_delayed().flatten(),
                dask_xx.to_delayed().flatten(),
                dask_yy.to_delayed().flatten())]
gd_out = delayed(da.concatenate)(gd_chunked, axis=0)
gd_out.visualize("dask_par.png")
gd1 = np.array(gd_out.compute())
print(gd1)
assert gd1.shape == (ar_size, ar_size)
print(gd1.shape)
plt.figure()
plt.imshow(gd1)
plt.savefig("dask_par_sin.png")

# prove we have no more nans in the data
assert ~np.isnan(np.sum(gd1))

Есть некоторые проблемы с этой реализацией. Griddata не может экстраполировать, поэтому nans является проблемой на границах чанка. Вы могли бы решить эту проблему с помощью нескольких перекрывающихся ячеек. В качестве временного решения вы можете использовать method = 'near' или попробовать интерполяция радиальной базисной функции .

...