Как правильно реализовать `__getitem__` для numpy пользовательских подклассов? - PullRequest
0 голосов
/ 06 февраля 2020

Я пытаюсь реализовать numpy подкласс State, который может обращаться к его элементам по атрибутам.

У меня есть дескриптор:

import numpy as np

from collections.abc import Sequence
from typing import Tuple, Union, Type


class _StateDescriptor:
    """ This is a custom descriptor to be used within the State class. It
    provides the possibility of accessing the ith-element of the numpy.ndarray
    by name.

    Parameters
    ---------
    i
        The index of the element in the numpy array to be accessed with

    """

    def __init__(self, i):
        self.i = i

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return obj[self.i]

    def __set__(self, obj, value):
        obj[self.i] = value

, а затем у меня есть class State:

class State(np.ndarray):
    """ :class:`State` is a subclass of :class:`numpy.ndarray`. It behaves like
    a normal :class:`numpy.ndarray` except it can be initialized a bit more
    expressively. In particular each element of the array can be accessed by
    name (in the order variables are provided)

    Attributes
    ----------
    fields
        Each element of the tuple is the name of the corresponding element in
        the array
    """

    fields: Tuple[str]

    def __new__(cls, *args):
        values = args

        cls._set_descriptors(cls)

        arr = np.asarray(list(values), dtype=float).view(cls)

        return arr

    def __getitem__(self, item):
        self._slice_index = item

        return super().__getitem__(item)

    def __array_finalize__(self, obj):
        # Normal __new__ construction
        if obj is None:
            return

        elif self is not None:  # Here we are a view or a new-from-template
            # If we are a new-from-template, need also to slice the fields
            # to the one selected by the slice
            if isinstance(self, State) and (isinstance(obj, State)):
                try:
                    slice_index = obj._slice_index
                except AttributeError:
                    # When the _slice_index attribute is not available,
                    # we're in a strange state, probably numpy performing
                    # strange slicing for operations. We do nothing
                    return

                cls = self.__class__
                if isinstance(slice_index, Sequence):
                    # cls.fields is a tuple and can be indexed only by
                    # integer or slice. With a sequence we manually select
                    # the values specified by the list
                    newfields = []

                    # First we need to get the actual fields that
                    # are requested by the "slice"
                    for i in obj._slice_index:
                        newfields.append(cls.fields[i])

                    newfields = tuple(newfields)
                else:
                    try:
                        newfields = self.fields[slice_index]
                    except AttributeError:
                        pass

                self._sync_descriptors(newfields)
            self._set_descriptors(self)

    @staticmethod
    def _set_descriptors(obj: Union[Type["State"], "State"]):
        """ This method sets the descriptor to each variable of the state
        provided by :class:_StateDescriptor for the class. Can be called
        on the class or on a instance of :class:`State`
        """

        if isinstance(obj, State):
            cls = obj.__class__
            if not (len(obj.fields) == len(obj)):
                raise ValueError(
                    "The length of the array is not equal to the number of "
                    "its fields"
                )
        else:
            cls = obj

        for i, field in enumerate(obj.fields):
            setattr(cls, field, _StateDescriptor(i))

    def _sync_descriptors(self, newfields: Tuple[str]):
        """ This method syncs the descriptors removing unused variables """
        # We check on the full set of variables of the State
        # which ones are requested by the "slice"
        cls = self.__class__
        for old_var in self.fields:
            if not (old_var in newfields):  # Requested
                # We remove the associated attribute
                delattr(cls, old_var)

        self.fields = newfields

Похоже, что это работает для общих операций

QClass = State
QClass.fields = ("rho", "rhoU", "rhoV", "p")
Q = QClass(0, 0, 0, 0)

Q.rhoV = 12
Q_slice = Q[[0, 2]]

# Q_slice.rhoU # Raises AttributeError
# Q_slice.p # Raises AttributeError

assert Q_slice.rho == 0
assert Q_slice.rhoV == 12
Q = QClass(0, 0, 0, 0)
Q.rhoV = 12
Q_slice = Q[:3]

assert Q_slice.rhoV == 12
assert Q_slice.rhoV == Q_slice[-1]

Проблема возникает при выполнении широковещательных операций:

Кажется, что это останавливается работает ... случайно. Если я не использую pdb, исключение не возникает.

arr = np.zeros((5, len(Q.fields)))
Q - arr
__import__("pdb").set_trace()
# Type p Q
# Type p Q - arr
# Raises ValueError

Это создает ситуацию, когда объекты obj и self в __array_finalize__ относятся к одному типу, но имеют разное количество элементов и это приводит к вопросам:

  • Как правильно реализовать нарезку для этого пользовательского подкласса, чтобы удалить дескриптор, связанный с полями, которые не находятся внутри среза (или список индексов)?
  • Сколько особых случаев мне нужно обработать в методе __array_finalize__?
...