Почему dask throw IndexError: индекс 1 выходит за пределы оси 0 с размером 1, когда я использую apply_along_axis? - PullRequest
0 голосов
/ 26 мая 2020

Я пытаюсь использовать dask.array.apply_along_axis для 2D-массива. Однако мой массив представляет собой массив dask, он всегда выдает исключение, подобное следующему:

Traceback (most recent call last):
  File "D:/test/apply_along_axis_test.py", line 22, in <module>
    b = da.apply_along_axis(lambda  a: a[index_array], 1, source_array)
  File "D:\Program Files\Python3\lib\site-packages\dask\array\routines.py", line 383, in apply_along_axis
    test_result = np.array(func1d(test_data, *args, **kwargs))
  File "D:/test/apply_along_axis_test.py", line 22, in <lambda>
    b = da.apply_along_axis(lambda  a: a[index_array], 1, source_array)
IndexError: index 1 is out of bounds for axis 0 with size 1

Однако, когда я применяю этот метод к numpy.array. Он может работать успешно.

Пример кода выглядит следующим образом:

source_array = np.random.randint(0, 10, (2, 4))
index_array = np.asarray([[0, 0], [1, 0], [2, 1], [3, 2]])

b = np.apply_along_axis(lambda a: a[index_array], 1, source_array)
print(b)

source_array = da.from_array(source_array)
b = da.apply_along_axis(lambda  a: a[index_array], 1, source_array)

Я могу успешно напечатать b. Однако последняя строка кода вызовет исключение. Я думаю, что мне следует использовать какой-нибудь метод карты, например map_partitions. Однако я не могу найти такой метод в dask.array.

1 Ответ

1 голос
/ 28 мая 2020

Я думаю, это следует решить, определив shape и dtype. Вы можете сделать это вручную или использовать make_meta, чтобы определить, что это должно быть:

In [55]: from dask.dataframe.utils import make_meta

In [56]: da.apply_along_axis(lambda  a: a[index_array], 1, source_array,
    ...:                     shape=make_meta(source_array).shape,
    ...:                     dtype=make_meta(source_array).dtype).compute()
Out[56]:
array([[[2, 2],
        [1, 2],
        [1, 1],
        [6, 1]],

       [[1, 1],
        [6, 1],
        [9, 6],
        [3, 9]]])

Вы также не первый, кто столкнулся с этой проблемой: https://github.com/dask/dask/issues/3727

...