Как передать правильный тип данных в класс с помощью Numba? - PullRequest
0 голосов
/ 28 февраля 2020

Я пытаюсь использовать numba , чтобы мой код выполнялся быстрее. Однако код выдает следующую ошибку:

This error may have been caused by the following argument(s):
- argument 0: Unsupported array dtype: object
- argument 1: Unsupported array dtype: object

У меня есть класс, написанный каким-то образом:

spec = [
    ('train_x', float64[:,:]),
    ('train_y', float64[:]),
    ('test_x', float64[:,:]),
    ('test_y', float64[:]),
]

@jitclass(spec)
class num_features:
    def __init__(self, train_x,  test_x, train_y, test_y):
        self.train_x, self.train_y = train_x, train_y
        self.test_x, self.test_y = test_x, test_y
        self.X_train, self.Y_train = [] , []
        self.X_test, self.Y_test = [] , []

    @property
    def extract_stats(self, matrix):
    ...

Я вызываю класс как

obj = num_features(train_x.to_numpy(), test_x.to_numpy(), train_y, test_y)

train_x и test_x pandas датафрейм.

1 Ответ

2 голосов
/ 28 февраля 2020

В вашем коде несколько ошибок. Во-первых, вы не можете использовать обычные списки python в numba-классе, все атрибуты должны быть набраны. Вам нужно будет указать оба атрибута как ListTypes и назначить им тип, который они будут содержать, например float64.

Во-вторых, настоящая ошибка, которую вы видите, заключается в том, что вы пытаетесь передать вход train_x и test_x как numpy массивы, которые содержат данные, НЕ являющиеся float64. Вот что говорит вам ошибка "Unsupported array dtype: object": ваши массивы для аргумента 0 и аргумента 1 являются массивами объектов или массивами python объектов.

Когда вы конвертируете их в numpy массивы, передайте dtype.

Кроме того, не увлекайтесь назначениями кортежей, numba достаточно привередлив, просто ставьте их по одному в строке .

from numba import jitclass, float64, typed, types

spec = [
    ('train_x', float64[:,:]),
    ('train_y', float64[:]),
    ('test_x', float64[:,:]),
    ('test_y', float64[:]),
    ('X_train', types.ListType(types.float64)),
    ('Y_train', types.ListType(types.float64)),
    ('X_test', types.ListType(types.float64)),
    ('Y_test', types.ListType(types.float64))
]

@jitclass(spec)
class num_features:
    def __init__(self, train_x,  test_x, train_y, test_y):
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y
        self.X_train = typed.List.empty_list(types.float64)
        self.Y_train = typed.List.empty_list(types.float64)
        self.X_test = typed.List.empty_list(types.float64)
        self.Y_test = typed.List.empty_list(types.float64)

    @property
    def extract_stats(self, matrix):
    ...

Теперь, чтобы фактически вызвать класс, вам нужно передать массивы float64. Вы можете использовать:

obj = num_features(train_x.to_numpy(np.float64),
                   test_x.to_numpy(np.float64),
                   train_y.astype(np.float64),
                   test_y.astype(np.float64))
...