Реализуйте __eq__ для объектов Python, которые содержат глубоко вложенные массивы numpy - PullRequest
1 голос
/ 21 апреля 2019

У меня проблемы с тем, что числовые массивы не сопоставимы с == (используя семантику np.array_equal) в контексте атрибутов объекта.

Рассмотрим следующий пример:

>>> import numpy as np
>>> class A:
...     def __init__(self, a):
...         self.a = a
...     def __eq__(self, other):
...         return self.__dict__ == other.__dict__
...
>>> x = A(a=[1, np.array([1, 2])])
>>> y = A(a=[1, np.array([1, 2])])
>>> x == y
Traceback (most recent call last):
  File "<ipython-input-33-9cfbd892cdaa>", line 1, in <module>
    x == y
  File "<ipython-input-30-790950997d4f>", line 5, in __eq__
    return self.__dict__ == other.__dict__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

(игнорируйте, что __eq__ не идеально, следует хотя бы проверить тип other, но это для краткости)

Как бы я внедрил __eq__функция, которая обрабатывает вложенные массивы, вложенные глубоко в мои атрибуты объекта (если предположить, что все остальное, например, список в этом примере, отлично сравнивается с ==)?Числовые массивы могут появляться на сколь угодно глубоком уровне вложенности внутри списков, кортежей или диктов.

Я пытался придумать "ручную" реализацию рекурсивной функции eq, которая применяет == квсе атрибуты и использует np.array_equal всякий раз, когда встречается с пустым массивом, но это хитрее, чем ожидалось.

У кого-нибудь есть подходящая функция или простой обходной путь?

1 Ответ

0 голосов
/ 21 апреля 2019

Если есть возможность изменить объекты x и y, вы можете перезаписать __eq__ метод np.ndarray в соответствии с вашими предпочтениями.

class eqarr(np.ndarray):
    def __eq__(self, other):
        return np.array_equal(self, other)

class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return self.__dict__ == other.__dict__

x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y

Это перезапускается в True.

Если это невозможно, единственное решение, о котором я могу подумать в данный момент, - это реализовать рекурсивную функцию проверки равенства.Моя попытка заключается в следующем:

def eq(a, b):
    if not (hasattr(a, '__iter__') or type(a) == str):
        return a == b

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False


class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return eq(self.__dict__, other.__dict__)

С вашими примерами и всеми теми, которые я придумал, это сработало.Решение должно быть применимо, когда вложенные объекты имеют атрибуты __iter__ и __len__.

Я надеюсь, что я учел все возможные ошибки, но вам может понадобиться немного изменить код, чтобы сделатьэто абсолютно отказоустойчиво.

Если вы найдете контрпример, пожалуйста, предоставьте его в качестве комментария.Я уверен, что код можно соответствующим образом скорректировать.

Производительность eq может быть не очень хорошей, но я не знаю, является ли это серьезной проблемой для вас.

Если массивы numyдовольно редки в вашей иерархии (и часто близки к началу), вы всегда можете сначала попробовать нормальное сравнение.Это может выглядеть следующим образом:

def eq(a, b):
    try:
        return np.all(a == b)
    except ValueError:
        pass

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...