Создание типа данных numpy из структуры Cython - PullRequest
1 голос
/ 18 июня 2020

Ниже приведен фрагмент кода Cython, который в настоящее время используется в двоичных деревьях scikit-learn,

  # Some compound datatypes used below:
  cdef struct NodeHeapData_t:
      DTYPE_t val
      ITYPE_t i1
      ITYPE_t i2

  # build the corresponding numpy dtype for NodeHeapData
  cdef NodeHeapData_t nhd_tmp
  NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

(полный исходный код здесь )

Последняя строка создает numpy dtype из этой структуры Cython. Мне не удалось найти много документации по этому поводу, и, в частности, я не понимаю, зачем нужна нарезка [:1] и для чего она нужна. Дополнительное обсуждение можно найти в scikit-learn # 17228 . Есть ли у кого-нибудь идеи по этому поводу?

1 Ответ

2 голосов
/ 18 июня 2020

Это умный, но запутанный трюк!

Следующий код создает cython-array длины 1, потому что память, которую он использует (но не владеет!), Имеет ровно один элемент .

cdef NodeHeapData_t nhd_tmp
<NodeHeapData_t[:1]>(&nhd_tmp)

Теперь cython-array реализует протокол буфера и, таким образом, Cython имеет механизм для создания format -строки, которая описывает тип элемента, который он содержит.

np.asarray также использует протокол буфера и может создавать dtype -объект из format -строки, которая предоставляется массивом cython.

Вы можете увидеть format-string через:

%%cython
import numpy as np

# Some compound datatypes used below:
cdef struct NodeHeapData_t:
  double val
  int i1
  int i2

# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

print("format string:",memoryview(<NodeHeapData_t[:1]>(&nhd_tmp)).format)
print(NodeHeapData )

, что приводит к

format string: T{d:val:i:i1:i:i2:}
[('val', '<f8'), ('i1', '<i4'), ('i2', '<i4')]

Я не могу придумать менее запутанного решения, кроме создания dtype - объект вручную - что может показаться некрасивым для некоторых типов данных на разных платформах *, но должно быть прямым в большинстве случаев.


*) np.int - такой проблемный случай c. Легко упустить из виду, что np.int отображается на long, а не int (сбивает с толку, не так ли?).

Например,

memoryview(np.zeros(1, dtype=np.int)).itemsize

оценивается как

  • На Windows: 4 (размер long в байтах на Windows).
  • На Linux: 8 (размер long в байтах на Linux ).
...