Скудная разреженная матрица возвращает IndexError - PullRequest
1 голос
/ 17 мая 2019

Если я пытаюсь нарезать разреженную матрицу или увидеть значение в заданном [row,colum], я получаю IndexError

Точнее, у меня есть следующее scipy.sparse.csr_matrix, которое я загружаю из файлапосле сохранения

...
>>> A = scipy.sparse.csr_matrix((vals, (rows, cols)), shape=(output_dim, input_dim))
>>> np.save(open('test_matrix.dat', 'wb'), A)
...
>>> A = np.load('test_matrix.dat', allow_pickle=True)
>>> A
array(<831232x798208 sparse matrix of type '<class 'numpy.float32'>'
    with 109886100 stored elements in Compressed Sparse Row format>,
      dtype=object)

Однако, когда я пытаюсь получить значение для данной пары [строка, столбец], я получаю следующую ошибку

>>> A[1,1]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for array

Почему это происходит?

Просто чтобы уточнить, я уверен, что матрица не пуста, так как я могу видеть ее содержание, если я сделаю

>>> print(A)
  (0, 1)    0.24914551
  (0, 2)    0.6669922
  (1, 1)    0.75097656
  (1, 3)    0.6640625
  (2, 3)    0.3359375
  (2, 514)  0.34960938
...

1 Ответ

0 голосов
/ 17 мая 2019

Когда вы сохраняете и перезагружаете свой разреженный массив, вы создали массив с одной записью; объект, являющийся вашим разреженным массивом. Так что А ничего не имеет в [1,1]. Вы должны использовать scipy.sparse.save_npz вместо.

Например:

import scipy.sparse as sps
import numpy as np

A = sps.csr_matrix((10,10))
A
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
np.save('test_matrix.dat', A)
B = np.load('test_matrix.dat.npy', allow_pickle=True)
B
array(<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>, dtype=object)
B[1,1]
IndexError                                Traceback (most recent call last)
<ipython-input-101-969f8bd5206a> in <module>
----> 1 B[1,1]

IndexError: too many indices for array
sps.save_npz('sparse_dat')
C = sps.load_npz('sparse_dat.npz')
C
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
C[1,1]
0.0

Имейте в виду, что вы все еще можете получить A из B примерно так:

D = B.tolist()
D
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 0 stored elements in Compressed Sparse Row format>
D[1,1]
0.0
...