Как переопределить метод getitem с индексированием на основе numpy ndarray? - PullRequest
1 голос
/ 20 января 2020

Я пытаюсь создать класс, основанный на numpy .ndarray, методы __getitem__ и __setitem__ которого предоставляют данные, основанные на индексации, например: Point class:

import numpy as np

class Point:
    def __init__(self, number=20):
        dt = np.dtype([("x",np.float64), ("y",np.float64), ("alive",np.bool)])
        self.points = np.zeros((int(number),1), dtype=dt)
        self.points["alive"] = True

#    def __getitem__(self, i):
#        mask = self.points["alive"] == True
#        print("get")
#        return self.points[mask].__getitem__(i)

    def __getitem__(self, i):
        mask = self.points["alive"] == True
        print("get")
        return self.points[mask][i]

    def __setitem__(self, i, item):
        mask = self.points["alive"] == True
        print("set")
        self.points[mask][i] = item

И если я try:

p = Point()
print(p[0])
>>>>get
>>>>(0., 0., True)
print(p[0]["alive"])
>>>>get
>>>>True
p[0]["alive"] = False
>>>>get
print(p.points[0]["alive"])
>>>>[ True]

Таким образом, изменение не учитывается, но я не получил ошибку, как будто я изменял копию. Также я запутался, потому что я не вызываю метод __setitem__, а метод __getitem__. Я попробовал другую реализацию, используя __getitem__ из ndarray, но есть та же проблема.

Что я делаю не так и как это сделать правильно?

1 Ответ

0 голосов
/ 14 февраля 2020

Я нашел, как это сделать, используя pandas, что позволяет делать это без цепной индексации и получать представление массива вместо копии:

import numpy as np
import pandas as pd

class Point:
    def __init__(self, number=20):
        d = {"x":np.zeros((int(number),)), "y":np.zeros((int(number),))}
        self.points = pd.DataFrame(d)
        self.alive = pd.Series(np.ones((int(number),),dtype=bool))

    def __getitem__(self, i):
        return self.points.loc[self.alive,i]

    def __setitem__(self, i, item):
        self.points.loc[self.alive,i] = item

Что дает правильное поведение:

p = Point(3)
print(p[:])

>>>>     x    y
>>>> 0  0.0  0.0
>>>> 1  0.0  0.0
>>>> 2  0.0  0.0

p.alive[0] = False
print(p[:])

>>>>     x    y
>>>> 1  0.0  0.0
>>>> 2  0.0  0.0

p["x"] = p["x"] + 5
print(p[:])

>>>>     x    y
>>>> 1  5.0  0.0
>>>> 2  5.0  0.0

...