Мне удалось добиться результата, который вы хотите адаптировать, который вы связали:
from scipy.sparse import csr_matrix
a = [[4, 0, 0], [0, 3, 0], [0, 0, 1]]
a = csr_matrix(a)
idx = a.argmax(axis=0)
m = a.shape[1]
a[idx,np.arange(m)[None,:]].toarray()
Выходы:
array([[4, 3, 1]], dtype=int32)