Как упаковывать и распаковывать массивы NumPy в Numba - PullRequest
0 голосов
/ 23 февраля 2020

Я пытаюсь определить пользовательский класс для использования с Numba. Я хотел бы, чтобы класс мог хранить массивы NumPy. Я следовал примеру Interval из документации Numba: https://numba.pydata.org/numba-doc/latest/extending/interval-example.html

Моя проблема в том, что я не могу найти никакой документации о том, как упаковывать и распаковывать массивы NumPy. Я сделал это до сих пор, но у меня нет идей, как упаковывать и распаковывать массив. Я предполагаю, что мне просто нужно заполнить правильный метод в вопросительных знаках. Любая помощь будет принята с благодарностью.

class Scenario(object):
    __slots__ = 'a'

    def __new__(cls, a):
        self = object.__new__(cls)
        self.a = a
        return self


class ScenarioType(types.Type):
    def __init__(self):
        super(ScenarioType, self).__init__(name='Scenario')


scenario_type = ScenarioType()


@typeof_impl.register(Scenario)
def typeof_index(val, c):
    return scenario_type


@type_callable(Scenario)
def type_interval(context):
    def typer(a):
        if isinstance(a, types.Array):
            return scenario_type
    return typer


@register_model(ScenarioType)
class ScenarioModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('a', types.float64[:]),
        ]
        models.StructModel.__init__(self, dmm, fe_type, members)


make_attribute_wrapper(ScenarioType, 'a', 'a')


@lower_builtin(Scenario, types.Array)
def impl_scenario(context, builder, sig, args):
    typ = sig.return_type
    a = args
    scenario = cgutils.create_struct_proxy(typ)(context, builder)
    scenario.a = a
    return scenario._getvalue()


@unbox(ScenarioType)
def unbox_scenario(typ, obj, c):
    a_obj = c.pyapi.object_getattr_string(obj, "a")
    scenario = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    scenario.a = c.pyapi.???? # What to do here?
    c.pyapi.decref(a_obj)
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(scenario._getvalue(), is_error=is_error)


@box(ScenarioType)
def box_scenario(typ, val, c):
    scenario = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    a_obj = c.pyapi.???? # What to do here?
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Scenario))
    res = c.pyapi.call_function_objargs(class_obj, (a_obj,))
    c.pyapi.decref(a_obj)
    c.pyapi.decref(class_obj)
    return res


@jit(nopython=True)
def pass_scenario(my_scen):
    res = my_scen
    return res


scen_instance = Scenario(np.array([2.0, 3.0])

print(pass_scenario(scen_instance.a))
...