Numpy: создание класса симметричной матрицы - PullRequest
0 голосов
/ 03 мая 2011

на основе этого ответа Я кодировал простой класс для симметричных матриц в python с numpy, но у меня возникла (возможно, очень простая) проблема. Это проблемный код:

import numpy as np

 class SyMatrix(np.ndarray):
    def __init__(self, arr):
        self = (arr + arr.T)/2.0 - np.diag(np.diag(arr)) 
    def __setitem__(self,(i,j), val):
        np.ndarray.__setitem__(self, (i, j), value)
        np.ndarray.__setitem__(self, (j, i), value)

Помимо этого неправильного ощущения (я не знаю, является ли хорошая рекомендация для self ...) Когда я пытаюсь создать новый массив, я получаю следующее:

>>> foo = SyMatrix( np.zeros(shape = (2,2)))
Traceback (most recent call last):
   File "<stdin>", line 1, in <module>
TypeError: only length-1 arrays can be converted to Python scalars

Я тоже пробовал:

import numpy as np

 class SyMatrix(np.ndarray):
    def __init__(self, n):
        self =  np.zeros(shape = (n,n)).view(SyMatrix)  
    def __setitem__(self,(i,j), val):
        np.ndarray.__setitem__(self, (i, j), value)
        np.ndarray.__setitem__(self, (j, i), value)

И тогда я получаю:

>>> foo = SyMatrix(2)
>>> foo
SyMatrix([  6.93581448e-310,   2.09933710e-316])
>>> 

где я ожидал массив с shape=(2,2). Как правильно делать то, что я пытаюсь сделать? Назначение на self проблематично?

1 Ответ

3 голосов
/ 03 мая 2011

Здесь есть несколько проблем.

  1. Когда подклассы numpy.ndarray(), вы должны перезаписать __new__(), а не __init__(). Ваша линия

    foo = SyMatrix(2)
    

    фактически вызывает numpy.ndarray.__new__() с параметром 2, несовместимым с его подписью .

  2. Назначение self здесь абсолютно ничего не делает. Он просто создает объект и заставляет локальное имя self указывать на этот объект. Как только функция завершается, все локальные имена удаляются. Присвоение в Python ни не создает переменные , ни не изменяет объекты ; он просто присваивает имя существующему объекту.

  3. Даже при устранении последних двух проблем ваш класс симметричной матрицы не будет работать должным образом. Существует буквально десятки методов, которые вам необходимо перезаписать, чтобы матрица всегда была симметричной.

  4. (arr + arr.T)/2.0 - np.diag(np.diag(arr)) скорее всего, не то, что вы хотите. На диагонали всегда будут нули. Вы, вероятно, хотите (arr + arr.T)/2.0.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...