Как использовать индекс одного массива для определения __getitem__ другого? - PullRequest
0 голосов
/ 10 декабря 2018

У меня есть объект пользовательского класса Field, который по сути обтекает объект numpy.ndarray.Объект определяется из двух входных данных: массива значений (values) и объекта среза (segment), который определяет, где эти значения должны быть помещены в некоторый больший массив (grid).

Я хотел бы иметь возможность использовать индексы grid для доступа к элементам values.Это должно быть возможно путем определения пользовательского метода Field.__getitem__.

import numpy as np

class Field:
    def __init__(self, values, segment, grid):
        if (not isinstance(segment, slice)) \\
        or (not isinstance(values, np.ndarray)) :
            raise TypeError
        if segment.step not in [1, -1]:
            raise ValueError('Segment must be continuous')
        if len(grid[segment]) != len(values):
            raise ValueError('values length must match segment')

        self.values = values
        self.segment = segment 
        self.grid = grid

    def __getitem__(self, key):
        new_key = ...  # <--- Code goes here
        return self.values[new_key]

grid = np.array([0.5, 1.5, 2.5, 3.5, 4.5])

values = np.array([42., 43., 44.])
segment = slice(2, 5)

my_field = Field(values, segment, grid)
print(grid[segment])  # output: [2.5, 3.5, 4.5]
print(my_field[2])  # Desired output: 42.
print(my_field[3])  # Desired output: 43.
print(my_field[0])  # Desired output: IndexError

Дело в том, что segment определяет набор позиций в grid, где определено my_field.Способ, которым я подошел к этому, оказался очень неэлегатным и неуклюжим и основывался на определении некоторого массива логических index = np.zeros_like(grid, dtype=bool); index[segment] = True, а затем включал некоторые трюки с np.cumsum(index) ...

Как я могудобиться этого поведения проще?

1 Ответ

0 голосов
/ 10 декабря 2018

Вы можете определить свой срез с явным шагом:

segment = slice(2, 5, 1)

Это должно обеспечить segment.step в ваших __init__ возвратах 1.Затем определите метод, который проверяет, находится ли ваш ввод key в соответствующем range:

def __getitem__(self, key):
    start, stop = self.segment.start, self.segment.stop
    new_key = key - start
    if new_key not in range(stop - start):
        raise IndexError(f'Key must be in range({start}, {stop})')
    return self.values[new_key]

. Это дает:

my_field = Field(values, segment, grid)
print(grid[segment])  # [2.5, 3.5, 4.5]
print(my_field[2])    # 42.0
print(my_field[3])    # 43.0
print(my_field[0])    # IndexError: Key must be in range(2, 5)
...