Я пытаюсь реализовать numpy подкласс State
, который может обращаться к его элементам по атрибутам.
У меня есть дескриптор:
import numpy as np
from collections.abc import Sequence
from typing import Tuple, Union, Type
class _StateDescriptor:
""" This is a custom descriptor to be used within the State class. It
provides the possibility of accessing the ith-element of the numpy.ndarray
by name.
Parameters
---------
i
The index of the element in the numpy array to be accessed with
"""
def __init__(self, i):
self.i = i
def __get__(self, obj, objtype=None):
if obj is None:
return self
return obj[self.i]
def __set__(self, obj, value):
obj[self.i] = value
, а затем у меня есть class State
:
class State(np.ndarray):
""" :class:`State` is a subclass of :class:`numpy.ndarray`. It behaves like
a normal :class:`numpy.ndarray` except it can be initialized a bit more
expressively. In particular each element of the array can be accessed by
name (in the order variables are provided)
Attributes
----------
fields
Each element of the tuple is the name of the corresponding element in
the array
"""
fields: Tuple[str]
def __new__(cls, *args):
values = args
cls._set_descriptors(cls)
arr = np.asarray(list(values), dtype=float).view(cls)
return arr
def __getitem__(self, item):
self._slice_index = item
return super().__getitem__(item)
def __array_finalize__(self, obj):
# Normal __new__ construction
if obj is None:
return
elif self is not None: # Here we are a view or a new-from-template
# If we are a new-from-template, need also to slice the fields
# to the one selected by the slice
if isinstance(self, State) and (isinstance(obj, State)):
try:
slice_index = obj._slice_index
except AttributeError:
# When the _slice_index attribute is not available,
# we're in a strange state, probably numpy performing
# strange slicing for operations. We do nothing
return
cls = self.__class__
if isinstance(slice_index, Sequence):
# cls.fields is a tuple and can be indexed only by
# integer or slice. With a sequence we manually select
# the values specified by the list
newfields = []
# First we need to get the actual fields that
# are requested by the "slice"
for i in obj._slice_index:
newfields.append(cls.fields[i])
newfields = tuple(newfields)
else:
try:
newfields = self.fields[slice_index]
except AttributeError:
pass
self._sync_descriptors(newfields)
self._set_descriptors(self)
@staticmethod
def _set_descriptors(obj: Union[Type["State"], "State"]):
""" This method sets the descriptor to each variable of the state
provided by :class:_StateDescriptor for the class. Can be called
on the class or on a instance of :class:`State`
"""
if isinstance(obj, State):
cls = obj.__class__
if not (len(obj.fields) == len(obj)):
raise ValueError(
"The length of the array is not equal to the number of "
"its fields"
)
else:
cls = obj
for i, field in enumerate(obj.fields):
setattr(cls, field, _StateDescriptor(i))
def _sync_descriptors(self, newfields: Tuple[str]):
""" This method syncs the descriptors removing unused variables """
# We check on the full set of variables of the State
# which ones are requested by the "slice"
cls = self.__class__
for old_var in self.fields:
if not (old_var in newfields): # Requested
# We remove the associated attribute
delattr(cls, old_var)
self.fields = newfields
Похоже, что это работает для общих операций
QClass = State
QClass.fields = ("rho", "rhoU", "rhoV", "p")
Q = QClass(0, 0, 0, 0)
Q.rhoV = 12
Q_slice = Q[[0, 2]]
# Q_slice.rhoU # Raises AttributeError
# Q_slice.p # Raises AttributeError
assert Q_slice.rho == 0
assert Q_slice.rhoV == 12
Q = QClass(0, 0, 0, 0)
Q.rhoV = 12
Q_slice = Q[:3]
assert Q_slice.rhoV == 12
assert Q_slice.rhoV == Q_slice[-1]
Проблема возникает при выполнении широковещательных операций:
Кажется, что это останавливается работает ... случайно. Если я не использую pdb
, исключение не возникает.
arr = np.zeros((5, len(Q.fields)))
Q - arr
__import__("pdb").set_trace()
# Type p Q
# Type p Q - arr
# Raises ValueError
Это создает ситуацию, когда объекты obj
и self
в __array_finalize__
относятся к одному типу, но имеют разное количество элементов и это приводит к вопросам:
- Как правильно реализовать нарезку для этого пользовательского подкласса, чтобы удалить дескриптор, связанный с полями, которые не находятся внутри среза (или список индексов)?
- Сколько особых случаев мне нужно обработать в методе
__array_finalize__
?