Python / Numba: проблема с созданием пользовательского типа с помощью Numba Extension API - PullRequest
2 голосов
/ 25 октября 2019

Я пытаюсь создать пользовательский тип Numba. У меня проблемы с упаковкой и распаковкой Numba Numpy Arrays в Native Numpy Arrays.

Я искал в Интернете похожие проблемы и следовал примеру документации в меру своих возможностей. (https://numba.pydata.org/numba-doc/latest/extending/interval-example.html).

Я пытался интерпретировать (https://github.com/numba/numba/blob/master/numba/targets/boxing.py), но это довольно сложно. Поэтому, я думаю, я мог бы сделать что-то маленькое неправильно.

Нижемоя текущая попытка включения массива Numpy в мой пользовательский тип.

import numpy as np
from numba import types, cgutils
from numba.extending import typeof_impl, type_callable, models
from numba.extending import register_model, make_attribute_wrapper, overload_attribute
from numba.extending import lower_builtin, unbox, NativeValue, box
class BMatrix(object):
    """
    A empty wrapper for a Binary Matrix
    """
    def __init__(self, m, n, row_index):#, col_index):
        self.m = m
        self.n = n
        self.row_index = row_index
        # self.col_i = col_index

    def __repr__(self):
        return 'BMatrix(%d, %d)' % (self.m, self.n)

    @property
    def shape(self):
        return (self.m, self.n)

class BMatrixType(types.Type):
    def __init__(self):
        super(BMatrixType, self).__init__(name='BMatrix')

bmatrix_type = BMatrixType()

@typeof_impl.register(BMatrix)
def typeof_index(val, c):
    return bmatrix_type


@type_callable(BMatrix)
def type_bmatrix(context):
    def typer(m, n, row_index):
        if (isinstance(m, types.Integer) 
            and isinstance(n, types.Integer) 
            and isinstance(row_index, nb.types.Array)):
            # and isinstance(col_index, nb.types.Array)):
            return bmatrix_type
    return typer


@register_model(BMatrixType)
class BMatrixModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('m', types.int64),
            ('n', types.int64),
            ('row_index', types.Array(types.int64, 1, 'C'))
            ]
        models.StructModel.__init__(self, dmm, fe_type, members)


make_attribute_wrapper(BMatrixType, 'm', 'm')
make_attribute_wrapper(BMatrixType, 'n', 'n')
make_attribute_wrapper(BMatrixType, 'row_index', 'row_index')


@overload_attribute(BMatrixType, "shape")
def get_shape(bmatrix):
    def getter(bmatrix):
        return (bmatrix.m, bmatrix.n)
    return getter


@lower_builtin(BMatrix, types.Integer, types.Integer, types.Array) #nb.types.Array, #nb.types.Array)
def impl_bmatrix(context, builder, sig, args):
    typ = sig.return_type
    m, n, row_index = args
    bmatrix = cgutils.create_struct_proxy(typ)(context, builder)
    bmatrix.m = m
    bmatrix.n = n
    bmatrix.row_index = row_index

    return bmatrix._getvalue()


@unbox(BMatrixType)
def unbox_bmatrix(typ, obj, c):
    """
    Convert a BMatrixType object to a native interval structure.
    """
    m_obj = c.pyapi.object_getattr_string(obj, "m")
    n_obj = c.pyapi.object_getattr_string(obj, "n")
    row_index_obj = c.pyapi.object_getattr_string(obj, "row_index")
    BMatrix = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    BMatrix.m = c.pyapi.long_as_longlong(m_obj)
    BMatrix.n = c.pyapi.long_as_longlong(n_obj)
    BMatrix.row_index = nb.targets.boxing.unbox_array(types.Array(types.int64, 1, 'C'),
                                                      row_index_obj, c)
    c.pyapi.decref(m_obj)
    c.pyapi.decref(n_obj)
    c.pyapi.decref(row_index_obj)
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(BMatrix._getvalue(), is_error=is_error)


@box(BMatrixType)
def box_bmatrix(typ, val, c):
    """
    Convert a native bmatrix structure to an BMatrix object.
    """
    Bmatrix = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    m_obj = c.pyapi.long_from_longlong(Bmatrix.m)
    n_obj = c.pyapi.long_from_longlong(Bmatrix.n)      
    row_index_obj = nb.targets.boxing.box_array(types.Array(types.int64, 1, 'C'),
                                                Bmatrix.row_index, c)

    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Bmatrix))
    res = c.pyapi.call_function_objargs(class_obj, (m_obj, n_obj))
    c.pyapi.decref(m_obj)
    c.pyapi.decref(n_obj)
    c.pyapi.decref(row_index_obj)
    c.pyapi.decref(class_obj)
    return res

Тестовые случаи (трассировки ошибок абсолютно массивны для test_2 и test_3).

@nb.jit(nopython=True)
def test_1(): #Runs
    x = BMatrix(10, 10, np.array([10,10,10]))

def test_2(): #Errors 
    x = BMatrix(10, 10, np.array([10,10,10]))
    @nb.jit(nopython=True)
    def _test_2(y):
        return y

    return _test_2(x)


@nb.jit(nopython=True)
def test_3(): #Errors
    return BMatrix(10, 10, np.array([10,10,10]))

@nb.jit(nopython=True)
def test_4():
    return BMatrix(10, 10, np.array([10,10,10])).row_index

Это ошибки при запуске тестовых случаев

test_1() #Runs
test_2() 
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-52-0f6d1bdba40b> in <module>
----> 1 test_2()

<ipython-input-51-60141c9792c1> in test_2()
      9         return y
     10 
---> 11     return _test_2(x)
     12 @nb.jit(nopython=True)
     13 def test_3():

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    368                     e.patch_message(''.join(e.args) + help_msg)
    369             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 370             raise e
    371 
    372     def inspect_llvm(self, signature=None):

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    325                 argtypes.append(self.typeof_pyval(a))
    326         try:
--> 327             return self.compile(tuple(argtypes))
    328         except errors.TypingError as e:
    329             # Intercept typing error that may be due to an argument

//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, sig)
    657 
    658             self._cache_misses[sig] += 1
--> 659             cres = self._compiler.compile(args, return_type)
    660             self.add_overload(cres)
    661             self._cache.save_overload(sig, cres)

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, args, return_type)
     81                                       args=args, return_type=return_type,
     82                                       flags=flags, locals=self.locals,
---> 83                                       pipeline_class=self.pipeline_class)
     84         # Check typing error if object mode is used
     85         if cres.typing_error is not None and not flags.enable_pyobject:

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    953     pipeline = pipeline_class(typingctx, targetctx, library,
    954                               args, return_type, flags, locals)
--> 955     return pipeline.compile_extra(func)
    956 
    957 

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(self, func)
    375         self.lifted = ()
    376         self.lifted_from = None
--> 377         return self._compile_bytecode()
    378 
    379     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_bytecode(self)
    884         """
    885         assert self.func_ir is None
--> 886         return self._compile_core()
    887 
    888     def _compile_ir(self):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_core(self)
    871         self.define_pipelines(pm)
    872         pm.finalize()
--> 873         res = pm.run(self.status)
    874         if res is not None:
    875             # Early pipeline completion

//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    252                     # No more fallback pipelines?
    253                     if is_final_pipeline:
--> 254                         raise patched_exception
    255                     # Go to next fallback pipeline
    256                     else:

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    243                 try:
    244                     event("-- %s" % stage_name)
--> 245                     stage()
    246                 except _EarlyPipelineCompletion as e:
    247                     return e.result

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in stage_nopython_backend(self)
    745         """
    746         lowerfn = self.backend_nopython_mode
--> 747         self._backend(lowerfn, objectmode=False)
    748 
    749     def stage_compile_interp_mode(self):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _backend(self, lowerfn, objectmode)
    685             self.library.enable_object_caching()
    686 
--> 687         lowered = lowerfn()
    688         signature = typing.signature(self.return_type, *self.args)
    689         self.cr = compile_result(

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in backend_nopython_mode(self)
    672                 self.calltypes,
    673                 self.flags,
--> 674                 self.metadata)
    675 
    676     def _backend(self, lowerfn, objectmode):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in native_lowering_stage(targetctx, library, interp, typemap, restype, calltypes, flags, metadata)
   1124         lower.lower()
   1125         if not flags.no_cpython_wrapper:
-> 1126             lower.create_cpython_wrapper(flags.release_gil)
   1127         env = lower.env
   1128         call_helper = lower.call_helper

//anaconda3/lib/python3.7/site-packages/numba/lowering.py in create_cpython_wrapper(self, release_gil)
    269         self.context.create_cpython_wrapper(self.library, self.fndesc,
    270                                             self.env, self.call_helper,
--> 271                                             release_gil=release_gil)
    272 
    273     def setup_function(self, fndesc):

//anaconda3/lib/python3.7/site-packages/numba/targets/cpu.py in create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
    155                                 fndesc, env, call_helper=call_helper,
    156                                 release_gil=release_gil)
--> 157         builder.build()
    158         library.add_ir_module(wrapper_module)
    159 

//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build(self)
    120 
    121         api = self.context.get_python_api(builder)
--> 122         self.build_wrapper(api, builder, closure, args, kws)
    123 
    124         return wrapper, api

//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build_wrapper(self, api, builder, closure, args, kws)
    153                 innerargs.append(None)
    154             else:
--> 155                 val = cleanup_manager.add_arg(builder.load(obj), ty)
    156                 innerargs.append(val)
    157 

//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in add_arg(self, obj, ty)
     30         """
     31         # Unbox argument
---> 32         native = self.api.to_native_value(ty, obj)
     33 
     34         # If an error occurred, go to the cleanup block for the previous argument.

//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in to_native_value(self, typ, obj)
   1423         impl = _unboxers.lookup(typ.__class__, unbox_unsupported)
   1424         c = _UnboxContext(self.context, self.builder, self)
-> 1425         return impl(typ, obj, c)
   1426 
   1427     def from_native_return(self, typ, val, env_manager):

<ipython-input-45-d8ac5afde794> in unbox_bmatrix(typ, obj, c)
     85     BMatrix.n = c.pyapi.long_as_longlong(n_obj)
     86     BMatrix.row_index = nb.targets.boxing.unbox_array(types.Array(types.int64, 1, 'C'),
---> 87                                                       row_index_obj, c)
     88     c.pyapi.decref(m_obj)
     89     c.pyapi.decref(n_obj)

//anaconda3/lib/python3.7/site-packages/numba/cgutils.py in __setattr__(self, field, value)
    162         if field.startswith('_'):
    163             return super(_StructProxy, self).__setattr__(field, value)
--> 164         self[self._datamodel.get_field_position(field)] = value
    165 
    166     def __getitem__(self, index):

//anaconda3/lib/python3.7/site-packages/numba/cgutils.py in __setitem__(self, index, value)
    177         ptr = self._get_ptr_by_index(index)
    178         value = self._cast_member_from_value(index, value)
--> 179         if value.type != ptr.type.pointee:
    180             if (is_pointer(value.type) and is_pointer(ptr.type.pointee)
    181                     and value.type.pointee == ptr.type.pointee.pointee):

AttributeError: Failed in nopython mode pipeline (step: nopython mode backend)
'NativeValue' object has no attribute 'type'
test_3()
KeyError                                  Traceback (most recent call last)
//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_object(self, obj)
   1403         try:
-> 1404             gv = self.module.__serialized[obj]
   1405         except KeyError:

KeyError: <numba.cgutils.ValueStructProxy_BMatrix object at 0x11e693f28>

During handling of the above exception, another exception occurred:

PicklingError                             Traceback (most recent call last)
<ipython-input-53-8d78c7c0acee> in <module>
----> 1 test_3()

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    368                     e.patch_message(''.join(e.args) + help_msg)
    369             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 370             raise e
    371 
    372     def inspect_llvm(self, signature=None):

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    325                 argtypes.append(self.typeof_pyval(a))
    326         try:
--> 327             return self.compile(tuple(argtypes))
    328         except errors.TypingError as e:
    329             # Intercept typing error that may be due to an argument

//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, sig)
    657 
    658             self._cache_misses[sig] += 1
--> 659             cres = self._compiler.compile(args, return_type)
    660             self.add_overload(cres)
    661             self._cache.save_overload(sig, cres)

//anaconda3/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, args, return_type)
     81                                       args=args, return_type=return_type,
     82                                       flags=flags, locals=self.locals,
---> 83                                       pipeline_class=self.pipeline_class)
     84         # Check typing error if object mode is used
     85         if cres.typing_error is not None and not flags.enable_pyobject:

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    953     pipeline = pipeline_class(typingctx, targetctx, library,
    954                               args, return_type, flags, locals)
--> 955     return pipeline.compile_extra(func)
    956 
    957 

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in compile_extra(self, func)
    375         self.lifted = ()
    376         self.lifted_from = None
--> 377         return self._compile_bytecode()
    378 
    379     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_bytecode(self)
    884         """
    885         assert self.func_ir is None
--> 886         return self._compile_core()
    887 
    888     def _compile_ir(self):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _compile_core(self)
    871         self.define_pipelines(pm)
    872         pm.finalize()
--> 873         res = pm.run(self.status)
    874         if res is not None:
    875             # Early pipeline completion

//anaconda3/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    252                     # No more fallback pipelines?
    253                     if is_final_pipeline:
--> 254                         raise patched_exception
    255                     # Go to next fallback pipeline
    256                     else:

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    243                 try:
    244                     event("-- %s" % stage_name)
--> 245                     stage()
    246                 except _EarlyPipelineCompletion as e:
    247                     return e.result

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in stage_nopython_backend(self)
    745         """
    746         lowerfn = self.backend_nopython_mode
--> 747         self._backend(lowerfn, objectmode=False)
    748 
    749     def stage_compile_interp_mode(self):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in _backend(self, lowerfn, objectmode)
    685             self.library.enable_object_caching()
    686 
--> 687         lowered = lowerfn()
    688         signature = typing.signature(self.return_type, *self.args)
    689         self.cr = compile_result(

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in backend_nopython_mode(self)
    672                 self.calltypes,
    673                 self.flags,
--> 674                 self.metadata)
    675 
    676     def _backend(self, lowerfn, objectmode):

//anaconda3/lib/python3.7/site-packages/numba/compiler.py in native_lowering_stage(targetctx, library, interp, typemap, restype, calltypes, flags, metadata)
   1124         lower.lower()
   1125         if not flags.no_cpython_wrapper:
-> 1126             lower.create_cpython_wrapper(flags.release_gil)
   1127         env = lower.env
   1128         call_helper = lower.call_helper

//anaconda3/lib/python3.7/site-packages/numba/lowering.py in create_cpython_wrapper(self, release_gil)
    269         self.context.create_cpython_wrapper(self.library, self.fndesc,
    270                                             self.env, self.call_helper,
--> 271                                             release_gil=release_gil)
    272 
    273     def setup_function(self, fndesc):

//anaconda3/lib/python3.7/site-packages/numba/targets/cpu.py in create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
    155                                 fndesc, env, call_helper=call_helper,
    156                                 release_gil=release_gil)
--> 157         builder.build()
    158         library.add_ir_module(wrapper_module)
    159 

//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build(self)
    120 
    121         api = self.context.get_python_api(builder)
--> 122         self.build_wrapper(api, builder, closure, args, kws)
    123 
    124         return wrapper, api

//anaconda3/lib/python3.7/site-packages/numba/callwrapper.py in build_wrapper(self, api, builder, closure, args, kws)
    174 
    175             retty = self._simplified_return_type()
--> 176             obj = api.from_native_return(retty, retval, env_manager)
    177             builder.ret(obj)
    178 

//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in from_native_return(self, typ, val, env_manager)
   1429                                                     "prevented the return of " \
   1430                                                     "optional value"
-> 1431         out = self.from_native_value(typ, val, env_manager)
   1432         return out
   1433 

//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in from_native_value(self, typ, val, env_manager)
   1443 
   1444         c = _BoxContext(self.context, self.builder, self, env_manager)
-> 1445         return impl(typ, val, c)
   1446 
   1447     def reflect_native_value(self, typ, val, env_manager=None):

<ipython-input-45-d8ac5afde794> in box_bmatrix(typ, val, c)
    104                                                 Bmatrix.row_index, c)
    105 
--> 106     class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Bmatrix))
    107     res = c.pyapi.call_function_objargs(class_obj, (m_obj, n_obj))
    108     c.pyapi.decref(m_obj)

//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_object(self, obj)
   1404             gv = self.module.__serialized[obj]
   1405         except KeyError:
-> 1406             struct = self.serialize_uncached(obj)
   1407             name = ".const.picklebuf.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")
   1408             gv = self.context.insert_unique_const(self.module, name, struct)

//anaconda3/lib/python3.7/site-packages/numba/pythonapi.py in serialize_uncached(self, obj)
   1383         """
   1384         # First make the array constant
-> 1385         data = pickle.dumps(obj, protocol=-1)
   1386         assert len(data) < 2**31
   1387         name = ".const.pickledata.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")

PicklingError: Failed in nopython mode pipeline (step: nopython mode backend)
Can't pickle <class 'numba.cgutils.ValueStructProxy_BMatrix'>: attribute lookup ValueStructProxy_BMatrix on numba.cgutils failed
test_4() #Runs Wrong
array([-2387225703656530210, -2387225703656530210, -2387225703656530210])
...