Я пытаюсь определить пользовательский класс для использования с Numba. Я хотел бы, чтобы класс мог хранить массивы NumPy. Я следовал примеру Interval из документации Numba: https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
Моя проблема в том, что я не могу найти никакой документации о том, как упаковывать и распаковывать массивы NumPy. Я сделал это до сих пор, но у меня нет идей, как упаковывать и распаковывать массив. Я предполагаю, что мне просто нужно заполнить правильный метод в вопросительных знаках. Любая помощь будет принята с благодарностью.
class Scenario(object):
__slots__ = 'a'
def __new__(cls, a):
self = object.__new__(cls)
self.a = a
return self
class ScenarioType(types.Type):
def __init__(self):
super(ScenarioType, self).__init__(name='Scenario')
scenario_type = ScenarioType()
@typeof_impl.register(Scenario)
def typeof_index(val, c):
return scenario_type
@type_callable(Scenario)
def type_interval(context):
def typer(a):
if isinstance(a, types.Array):
return scenario_type
return typer
@register_model(ScenarioType)
class ScenarioModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('a', types.float64[:]),
]
models.StructModel.__init__(self, dmm, fe_type, members)
make_attribute_wrapper(ScenarioType, 'a', 'a')
@lower_builtin(Scenario, types.Array)
def impl_scenario(context, builder, sig, args):
typ = sig.return_type
a = args
scenario = cgutils.create_struct_proxy(typ)(context, builder)
scenario.a = a
return scenario._getvalue()
@unbox(ScenarioType)
def unbox_scenario(typ, obj, c):
a_obj = c.pyapi.object_getattr_string(obj, "a")
scenario = cgutils.create_struct_proxy(typ)(c.context, c.builder)
scenario.a = c.pyapi.???? # What to do here?
c.pyapi.decref(a_obj)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(scenario._getvalue(), is_error=is_error)
@box(ScenarioType)
def box_scenario(typ, val, c):
scenario = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
a_obj = c.pyapi.???? # What to do here?
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Scenario))
res = c.pyapi.call_function_objargs(class_obj, (a_obj,))
c.pyapi.decref(a_obj)
c.pyapi.decref(class_obj)
return res
@jit(nopython=True)
def pass_scenario(my_scen):
res = my_scen
return res
scen_instance = Scenario(np.array([2.0, 3.0])
print(pass_scenario(scen_instance.a))