Многомерная диагональная матрица, где диагональные элементы - это векторы строк или столбцов - PullRequest
1 голос
/ 23 марта 2020

У меня есть трехмерная матрица x с формой (2,4,6)

x = np.arange(2*4*3).reshape(2,4,3)
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23]]])

Я хочу создать две диагональные матрицы ei и ej с формой (2,4,4) , 1,3) и (2,4,4,3,1) такие, что диагональные элементы ei и ej являются векторами строк и столбцов-векторов элементов в x.

0-й элемент ei: ei [0]

0-й элемент ej: ej [0]

Например

ei = array([[
        [[[ 0.,  1.,  2.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],
         [[ 3.,  4.,  5.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]]],

        ...

        [[[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[21., 22., 23.]]]]])

My Наивный способ приведен ниже

ei = np.zeros([x.shape[0],x.shape[1],x.shape[1],1,x.shape[2]])
ej = np.zeros([x.shape[0],x.shape[1],x.shape[1],x.shape[2],1])

for j in range(x.shape[0]):
  for i in range(x.shape[1]):
    ei[j,i,i] = x[j,i]
    ej[j,i,i] = np.transpose([x[j,i]])

Есть ли альтернативный способ сделать то же самое?

Ответы [ 2 ]

0 голосов
/ 23 марта 2020

Мы можем определить n-мерную нулевую матрицу и заполнить ее диагональ, используя расширенное индексирование:

def to_nd_diagonal(x):
    *i, j, k = x.shape
    a = np.zeros((*i,j,j,k))
    I = np.arange(j)
    a[...,I,I,:] = x
    out=a[...,None]
    return out, out.swapaxes(-1,-2)

Где мы получим:

ej, ei = to_nd_diagonal(x)

ei.shape
# (2, 4, 4, 1, 3)
ej.shape
# (2, 4, 4, 3, 1)

print(ei)

array([[[[[ 0.,  1.,  2.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 3.,  4.,  5.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 6.,  7.,  8.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 9., 10., 11.]]]],



       [[[[12., 13., 14.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[15., 16., 17.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[18., 19., 20.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[21., 22., 23.]]]]])

print(ej)

array([[[[[ 0.],
          [ 1.],
          [ 2.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 3.],
          [ 4.],
          [ 5.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 6.],
          [ 7.],
          [ 8.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 9.],
          [10.],
          [11.]]]],



       [[[[12.],
          [13.],
          [14.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[15.],
          [16.],
          [17.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[18.],
          [19.],
          [20.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[21.],
          [22.],
          [23.]]]]])
0 голосов
/ 23 марта 2020

Заменить петли range на индексирование np.arange:

In [260]: arr = np.arange(12).reshape(3,4)                                                                           
In [261]: res = np.zeros((3,3,4),int)                                                                                
In [262]: res[np.arange(3), np.arange(3),:] = arr                                                                    
In [263]: res                                                                                                        
Out[263]: 
array([[[ 0,  1,  2,  3],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 8,  9, 10, 11]]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...