Я пытаюсь создать пользовательский тип Numba. У меня проблемы с упаковкой и распаковкой Numba Numpy Arrays в Native Numpy Arrays.
Я искал в Интернете похожие проблемы и следовал примеру документации в меру своих возможностей. (https://numba.pydata.org/numba-doc/latest/extending/interval-example.html).
Я пытался интерпретировать (https://github.com/numba/numba/blob/master/numba/targets/boxing.py), но это довольно сложно. Поэтому, я думаю, я мог бы сделать что-то маленькое неправильно.
Нижемоя текущая попытка включения массива Numpy в мой пользовательский тип.
import numpy as np
from numba import types, cgutils
from numba.extending import typeof_impl, type_callable, models
from numba.extending import register_model, make_attribute_wrapper, overload_attribute
from numba.extending import lower_builtin, unbox, NativeValue, box
class BMatrix(object):
"""
A empty wrapper for a Binary Matrix
"""
def __init__(self, m, n, row_index):#, col_index):
self.m = m
self.n = n
self.row_index = row_index
# self.col_i = col_index
def __repr__(self):
return 'BMatrix(%d, %d)' % (self.m, self.n)
@property
def shape(self):
return (self.m, self.n)
class BMatrixType(types.Type):
def __init__(self):
super(BMatrixType, self).__init__(name='BMatrix')
bmatrix_type = BMatrixType()
@typeof_impl.register(BMatrix)
def typeof_index(val, c):
return bmatrix_type
@type_callable(BMatrix)
def type_bmatrix(context):
def typer(m, n, row_index):
if (isinstance(m, types.Integer)
and isinstance(n, types.Integer)
and isinstance(row_index, nb.types.Array)):
# and isinstance(col_index, nb.types.Array)):
return bmatrix_type
return typer
@register_model(BMatrixType)
class BMatrixModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('m', types.int64),
('n', types.int64),
('row_index', types.Array(types.int64, 1, 'C'))
]
models.StructModel.__init__(self, dmm, fe_type, members)
make_attribute_wrapper(BMatrixType, 'm', 'm')
make_attribute_wrapper(BMatrixType, 'n', 'n')
make_attribute_wrapper(BMatrixType, 'row_index', 'row_index')
@overload_attribute(BMatrixType, "shape")
def get_shape(bmatrix):
def getter(bmatrix):
return (bmatrix.m, bmatrix.n)
return getter
@lower_builtin(BMatrix, types.Integer, types.Integer, types.Array) #nb.types.Array, #nb.types.Array)
def impl_bmatrix(context, builder, sig, args):
typ = sig.return_type
m, n, row_index = args
bmatrix = cgutils.create_struct_proxy(typ)(context, builder)
bmatrix.m = m
bmatrix.n = n
bmatrix.row_index = row_index
return bmatrix._getvalue()
@unbox(BMatrixType)
def unbox_bmatrix(typ, obj, c):
"""
Convert a BMatrixType object to a native interval structure.
"""
m_obj = c.pyapi.object_getattr_string(obj, "m")
n_obj = c.pyapi.object_getattr_string(obj, "n")
row_index_obj = c.pyapi.object_getattr_string(obj, "row_index")
BMatrix = cgutils.create_struct_proxy(typ)(c.context, c.builder)
BMatrix.m = c.pyapi.long_as_longlong(m_obj)
BMatrix.n = c.pyapi.long_as_longlong(n_obj)
BMatrix.row_index = nb.targets.boxing.unbox_array(types.Array(types.int64, 1, 'C'),
row_index_obj, c)
c.pyapi.decref(m_obj)
c.pyapi.decref(n_obj)
c.pyapi.decref(row_index_obj)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(BMatrix._getvalue(), is_error=is_error)
@box(BMatrixType)
def box_bmatrix(typ, val, c):
"""
Convert a native bmatrix structure to an BMatrix object.
"""
Bmatrix = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
m_obj = c.pyapi.long_from_longlong(Bmatrix.m)
n_obj = c.pyapi.long_from_longlong(Bmatrix.n)
row_index_obj = nb.targets.boxing.box_array(types.Array(types.int64, 1, 'C'),
Bmatrix.row_index, c)
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Bmatrix))
res = c.pyapi.call_function_objargs(class_obj, (m_obj, n_obj))
c.pyapi.decref(m_obj)
c.pyapi.decref(n_obj)
c.pyapi.decref(row_index_obj)
c.pyapi.decref(class_obj)
return res
Тестовые случаи (трассировки ошибок абсолютно массивны для test_2 и test_3).
@nb.jit(nopython=True)
def test_1(): #Runs
x = BMatrix(10, 10, np.array([10,10,10]))
def test_2(): #Errors
x = BMatrix(10, 10, np.array([10,10,10]))
@nb.jit(nopython=True)
def _test_2(y):
return y
return _test_2(x)
@nb.jit(nopython=True)
def test_3(): #Errors
return BMatrix(10, 10, np.array([10,10,10]))
@nb.jit(nopython=True)
def test_4():
return BMatrix(10, 10, np.array([10,10,10])).row_index
Это ошибки при запуске тестовых случаев
test_1() #Runs
test_2()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-52-0f6d1bdba40b> in <module>
----> 1 test_2()
<ipython-input-51-60141c9792c1> in test_2()
9 return y
10
---> 11 return _test_2(x)
12 @nb.jit(nopython=True)
13 def test_3():
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
368 e.patch_message(''.join(e.args) + help_msg)
369 # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 370 raise e
371
372 def inspect_llvm(self, signature=None):
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
325 argtypes.append(self.typeof_pyval(a))
326 try:
--> 327 return self.compile(tuple(argtypes))
328 except errors.TypingError as e:
329 # Intercept typing error that may be due to an argument
//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, sig)
657
658 self._cache_misses[sig] += 1
--> 659 cres = self._compiler.compile(args, return_type)
660 self.add_overload(cres)
661 self._cache.save_overload(sig, cres)
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, args, return_type)
81 args=args, return_type=return_type,
82 flags=flags, locals=self.locals,
---> 83 pipeline_class=self.pipeline_class)
84 # Check typing error if object mode is used
85 if cres.typing_error is not None and not flags.enable_pyobject:
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
953 pipeline = pipeline_class(typingctx, targetctx, library,
954 args, return_type, flags, locals)
--> 955 return pipeline.compile_extra(func)
956
957
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(self, func)
375 self.lifted = ()
376 self.lifted_from = None
--> 377 return self._compile_bytecode()
378
379 def compile_ir(self, func_ir, lifted=(), lifted_from=None):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_bytecode(self)
884 """
885 assert self.func_ir is None
--> 886 return self._compile_core()
887
888 def _compile_ir(self):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_core(self)
871 self.define_pipelines(pm)
872 pm.finalize()
--> 873 res = pm.run(self.status)
874 if res is not None:
875 # Early pipeline completion
//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
252 # No more fallback pipelines?
253 if is_final_pipeline:
--> 254 raise patched_exception
255 # Go to next fallback pipeline
256 else:
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
243 try:
244 event("-- %s" % stage_name)
--> 245 stage()
246 except _EarlyPipelineCompletion as e:
247 return e.result
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in stage_nopython_backend(self)
745 """
746 lowerfn = self.backend_nopython_mode
--> 747 self._backend(lowerfn, objectmode=False)
748
749 def stage_compile_interp_mode(self):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _backend(self, lowerfn, objectmode)
685 self.library.enable_object_caching()
686
--> 687 lowered = lowerfn()
688 signature = typing.signature(self.return_type, *self.args)
689 self.cr = compile_result(
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in backend_nopython_mode(self)
672 self.calltypes,
673 self.flags,
--> 674 self.metadata)
675
676 def _backend(self, lowerfn, objectmode):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in native_lowering_stage(targetctx, library, interp, typemap, restype, calltypes, flags, metadata)
1124 lower.lower()
1125 if not flags.no_cpython_wrapper:
-> 1126 lower.create_cpython_wrapper(flags.release_gil)
1127 env = lower.env
1128 call_helper = lower.call_helper
//anaconda3/lib/python3.7/site-packages/numba/lowering.py in create_cpython_wrapper(self, release_gil)
269 self.context.create_cpython_wrapper(self.library, self.fndesc,
270 self.env, self.call_helper,
--> 271 release_gil=release_gil)
272
273 def setup_function(self, fndesc):
//anaconda3/lib/python3.7/site-packages/numba/targets/cpu.py in create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
155 fndesc, env, call_helper=call_helper,
156 release_gil=release_gil)
--> 157 builder.build()
158 library.add_ir_module(wrapper_module)
159
//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build(self)
120
121 api = self.context.get_python_api(builder)
--> 122 self.build_wrapper(api, builder, closure, args, kws)
123
124 return wrapper, api
//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build_wrapper(self, api, builder, closure, args, kws)
153 innerargs.append(None)
154 else:
--> 155 val = cleanup_manager.add_arg(builder.load(obj), ty)
156 innerargs.append(val)
157
//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in add_arg(self, obj, ty)
30 """
31 # Unbox argument
---> 32 native = self.api.to_native_value(ty, obj)
33
34 # If an error occurred, go to the cleanup block for the previous argument.
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in to_native_value(self, typ, obj)
1423 impl = _unboxers.lookup(typ.__class__, unbox_unsupported)
1424 c = _UnboxContext(self.context, self.builder, self)
-> 1425 return impl(typ, obj, c)
1426
1427 def from_native_return(self, typ, val, env_manager):
<ipython-input-45-d8ac5afde794> in unbox_bmatrix(typ, obj, c)
85 BMatrix.n = c.pyapi.long_as_longlong(n_obj)
86 BMatrix.row_index = nb.targets.boxing.unbox_array(types.Array(types.int64, 1, 'C'),
---> 87 row_index_obj, c)
88 c.pyapi.decref(m_obj)
89 c.pyapi.decref(n_obj)
//anaconda3/lib/python3.7/site-packages/numba/cgutils.py in __setattr__(self, field, value)
162 if field.startswith('_'):
163 return super(_StructProxy, self).__setattr__(field, value)
--> 164 self[self._datamodel.get_field_position(field)] = value
165
166 def __getitem__(self, index):
//anaconda3/lib/python3.7/site-packages/numba/cgutils.py in __setitem__(self, index, value)
177 ptr = self._get_ptr_by_index(index)
178 value = self._cast_member_from_value(index, value)
--> 179 if value.type != ptr.type.pointee:
180 if (is_pointer(value.type) and is_pointer(ptr.type.pointee)
181 and value.type.pointee == ptr.type.pointee.pointee):
AttributeError: Failed in nopython mode pipeline (step: nopython mode backend)
'NativeValue' object has no attribute 'type'
test_3()
KeyError Traceback (most recent call last)
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_object(self, obj)
1403 try:
-> 1404 gv = self.module.__serialized[obj]
1405 except KeyError:
KeyError: <numba.cgutils.ValueStructProxy_BMatrix object at 0x11e693f28>
During handling of the above exception, another exception occurred:
PicklingError Traceback (most recent call last)
<ipython-input-53-8d78c7c0acee> in <module>
----> 1 test_3()
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
368 e.patch_message(''.join(e.args) + help_msg)
369 # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 370 raise e
371
372 def inspect_llvm(self, signature=None):
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
325 argtypes.append(self.typeof_pyval(a))
326 try:
--> 327 return self.compile(tuple(argtypes))
328 except errors.TypingError as e:
329 # Intercept typing error that may be due to an argument
//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, sig)
657
658 self._cache_misses[sig] += 1
--> 659 cres = self._compiler.compile(args, return_type)
660 self.add_overload(cres)
661 self._cache.save_overload(sig, cres)
//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, args, return_type)
81 args=args, return_type=return_type,
82 flags=flags, locals=self.locals,
---> 83 pipeline_class=self.pipeline_class)
84 # Check typing error if object mode is used
85 if cres.typing_error is not None and not flags.enable_pyobject:
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
953 pipeline = pipeline_class(typingctx, targetctx, library,
954 args, return_type, flags, locals)
--> 955 return pipeline.compile_extra(func)
956
957
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(self, func)
375 self.lifted = ()
376 self.lifted_from = None
--> 377 return self._compile_bytecode()
378
379 def compile_ir(self, func_ir, lifted=(), lifted_from=None):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_bytecode(self)
884 """
885 assert self.func_ir is None
--> 886 return self._compile_core()
887
888 def _compile_ir(self):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_core(self)
871 self.define_pipelines(pm)
872 pm.finalize()
--> 873 res = pm.run(self.status)
874 if res is not None:
875 # Early pipeline completion
//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
252 # No more fallback pipelines?
253 if is_final_pipeline:
--> 254 raise patched_exception
255 # Go to next fallback pipeline
256 else:
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
243 try:
244 event("-- %s" % stage_name)
--> 245 stage()
246 except _EarlyPipelineCompletion as e:
247 return e.result
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in stage_nopython_backend(self)
745 """
746 lowerfn = self.backend_nopython_mode
--> 747 self._backend(lowerfn, objectmode=False)
748
749 def stage_compile_interp_mode(self):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _backend(self, lowerfn, objectmode)
685 self.library.enable_object_caching()
686
--> 687 lowered = lowerfn()
688 signature = typing.signature(self.return_type, *self.args)
689 self.cr = compile_result(
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in backend_nopython_mode(self)
672 self.calltypes,
673 self.flags,
--> 674 self.metadata)
675
676 def _backend(self, lowerfn, objectmode):
//anaconda3/lib/python3.7/site-packages/numba/compiler.py in native_lowering_stage(targetctx, library, interp, typemap, restype, calltypes, flags, metadata)
1124 lower.lower()
1125 if not flags.no_cpython_wrapper:
-> 1126 lower.create_cpython_wrapper(flags.release_gil)
1127 env = lower.env
1128 call_helper = lower.call_helper
//anaconda3/lib/python3.7/site-packages/numba/lowering.py in create_cpython_wrapper(self, release_gil)
269 self.context.create_cpython_wrapper(self.library, self.fndesc,
270 self.env, self.call_helper,
--> 271 release_gil=release_gil)
272
273 def setup_function(self, fndesc):
//anaconda3/lib/python3.7/site-packages/numba/targets/cpu.py in create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
155 fndesc, env, call_helper=call_helper,
156 release_gil=release_gil)
--> 157 builder.build()
158 library.add_ir_module(wrapper_module)
159
//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build(self)
120
121 api = self.context.get_python_api(builder)
--> 122 self.build_wrapper(api, builder, closure, args, kws)
123
124 return wrapper, api
//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build_wrapper(self, api, builder, closure, args, kws)
174
175 retty = self._simplified_return_type()
--> 176 obj = api.from_native_return(retty, retval, env_manager)
177 builder.ret(obj)
178
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in from_native_return(self, typ, val, env_manager)
1429 "prevented the return of " \
1430 "optional value"
-> 1431 out = self.from_native_value(typ, val, env_manager)
1432 return out
1433
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in from_native_value(self, typ, val, env_manager)
1443
1444 c = _BoxContext(self.context, self.builder, self, env_manager)
-> 1445 return impl(typ, val, c)
1446
1447 def reflect_native_value(self, typ, val, env_manager=None):
<ipython-input-45-d8ac5afde794> in box_bmatrix(typ, val, c)
104 Bmatrix.row_index, c)
105
--> 106 class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Bmatrix))
107 res = c.pyapi.call_function_objargs(class_obj, (m_obj, n_obj))
108 c.pyapi.decref(m_obj)
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_object(self, obj)
1404 gv = self.module.__serialized[obj]
1405 except KeyError:
-> 1406 struct = self.serialize_uncached(obj)
1407 name = ".const.picklebuf.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")
1408 gv = self.context.insert_unique_const(self.module, name, struct)
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_uncached(self, obj)
1383 """
1384 # First make the array constant
-> 1385 data = pickle.dumps(obj, protocol=-1)
1386 assert len(data) < 2**31
1387 name = ".const.pickledata.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")
PicklingError: Failed in nopython mode pipeline (step: nopython mode backend)
Can't pickle <class 'numba.cgutils.ValueStructProxy_BMatrix'>: attribute lookup ValueStructProxy_BMatrix on numba.cgutils failed
test_4() #Runs Wrong
array([-2387225703656530210, -2387225703656530210, -2387225703656530210])