Если есть возможность изменить объекты x
и y
, вы можете перезаписать __eq__
метод np.ndarray
в соответствии с вашими предпочтениями.
class eqarr(np.ndarray):
def __eq__(self, other):
return np.array_equal(self, other)
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return self.__dict__ == other.__dict__
x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y
Это перезапускается в True
.
Если это невозможно, единственное решение, о котором я могу подумать в данный момент, - это реализовать рекурсивную функцию проверки равенства.Моя попытка заключается в следующем:
def eq(a, b):
if not (hasattr(a, '__iter__') or type(a) == str):
return a == b
try:
if not len(a) == len(b):
return False
if type(a) == np.ndarray:
return np.array_equal(a, b)
if isinstance(a, dict):
return all(eq(v, b[k]) for k, v in a.items())
else:
return all(eq(aa, bb) for aa, bb in zip(a, b))
except (TypeError, KeyError):
return False
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return eq(self.__dict__, other.__dict__)
С вашими примерами и всеми теми, которые я придумал, это сработало.Решение должно быть применимо, когда вложенные объекты имеют атрибуты __iter__
и __len__
.
Я надеюсь, что я учел все возможные ошибки, но вам может понадобиться немного изменить код, чтобы сделатьэто абсолютно отказоустойчиво.
Если вы найдете контрпример, пожалуйста, предоставьте его в качестве комментария.Я уверен, что код можно соответствующим образом скорректировать.
Производительность eq
может быть не очень хорошей, но я не знаю, является ли это серьезной проблемой для вас.
Если массивы numyдовольно редки в вашей иерархии (и часто близки к началу), вы всегда можете сначала попробовать нормальное сравнение.Это может выглядеть следующим образом:
def eq(a, b):
try:
return np.all(a == b)
except ValueError:
pass
try:
if not len(a) == len(b):
return False
if type(a) == np.ndarray:
return np.array_equal(a, b)
if isinstance(a, dict):
return all(eq(v, b[k]) for k, v in a.items())
else:
return all(eq(aa, bb) for aa, bb in zip(a, b))
except (TypeError, KeyError):
return False