Перегрузка операторов в NumPy - PullRequest
0 голосов
/ 07 ноября 2018

У меня следующая проблема: я хочу создать и использовать массив numpy с небольшим изменением в операторе []. Я так понимаю, что пока это делается методом __getitem__(self, index). Однако я не могу понять, как это сделать, поэтому я объявляю массив, который является «массивом пустяков» во всех аспектах, кроме одной этой проблемы (скажем, ради примера, я хочу, чтобы array[i] интерпретировалось как array[i-1]

Я пытался решить это следующим образом:

class myarray(np.ndarray):
def __getitem__(self, index):
    return self[index+1]

k = np.linspace(0, 10, 10).view(myarray)

хотя на самом деле это не работает

Ответы [ 2 ]

0 голосов
/ 07 ноября 2018

Благодаря ответу onodip я решил свою начальную проблему. Это немного отличалось от того, что я написал, я научился быть более конкретным в будущем (не спрашивать на примере).

Первоначально я хотел перебрать матрицу "в циклах" - сделать индекс n + 1 равным 0 и т. Д. Для всех индексов - другими словами, учитывая их форму по модулю.

import numpy as np


class myarray(np.ndarray):

def __getitem__(self, index): 

    if isinstance(index, tuple):
        new_index = tuple(index[i] % super(myarray, self).shape[i] for i in range(len(index)))
    else:
        new_index = index % super(myarray, self).shape[0]

    return super(myarray, self).__getitem__(new_index)


my_k = np.linspace(0, 10, 10).view(myarray)
print(my_k)
print(my_k[7])
print(my_k[17])

Это был отличный урок для меня. Спасибо всем за ответы и ваше время!

0 голосов
/ 07 ноября 2018

Есть две проблемы с вашим кодом. Во-первых, индекс может быть также tuple (не просто int). Другое дело, что в ответ на вашу функцию вы получаете элемент с [], который также использует getitem . Это приведет к бесконечной рекурсии. Вы должны использовать функцию родительского класса с super()

import numpy as np


class myarray(np.ndarray):

    def __getitem__(self, index):
        if isinstance(index, tuple):
            index = index[0] + 1,
        else:
            index += 1
        return super(myarray, self).__getitem__(index)

my_k = np.linspace(0, 10, 10).view(myarray)
k = np.linspace(0, 10, 10).view(np.ndarray)
print(my_k)
print(k)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...