Мы можем определить n-мерную нулевую матрицу и заполнить ее диагональ, используя расширенное индексирование:
def to_nd_diagonal(x):
*i, j, k = x.shape
a = np.zeros((*i,j,j,k))
I = np.arange(j)
a[...,I,I,:] = x
out=a[...,None]
return out, out.swapaxes(-1,-2)
Где мы получим:
ej, ei = to_nd_diagonal(x)
ei.shape
# (2, 4, 4, 1, 3)
ej.shape
# (2, 4, 4, 3, 1)
print(ei)
array([[[[[ 0., 1., 2.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[ 3., 4., 5.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 6., 7., 8.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 9., 10., 11.]]]],
[[[[12., 13., 14.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[15., 16., 17.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[18., 19., 20.]],
[[ 0., 0., 0.]]],
[[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[ 0., 0., 0.]],
[[21., 22., 23.]]]]])
print(ej)
array([[[[[ 0.],
[ 1.],
[ 2.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[ 3.],
[ 4.],
[ 5.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 6.],
[ 7.],
[ 8.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 9.],
[10.],
[11.]]]],
[[[[12.],
[13.],
[14.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[15.],
[16.],
[17.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[18.],
[19.],
[20.]],
[[ 0.],
[ 0.],
[ 0.]]],
[[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[ 0.],
[ 0.],
[ 0.]],
[[21.],
[22.],
[23.]]]]])