Самый быстрый способ выполнения разреженных матричных умножений в Python - PullRequest
0 голосов
/ 19 апреля 2019

ПРЕДПОСЫЛКИ: Я пытаюсь построить имитационную модель барабана в реальном времени, для которой мне нужны действительно быстрые матрично-векторные продукты.Мои матрицы имеют размер ~ 5000-10000 строк / столбцов, из которых только 6 записей в строке ненулевые, поэтому я склонен использовать разреженные матрицы.Я использую scipy.sparse модуль.Итерации, как показано ниже.

Vjk_plus_sparse = Vjk_minus_sparse.transpose()
Vj = Vjk_plus_sparse.dot(constant)
np.put(Vj, Nr, 0.0)
Uj[t] = Uj[t-1] + np.transpose(Vj)/fs
Vj_mat = adj_mat_sparse.multiply(Vj)
Vjk_minus_sparse = Vj_mat-Vjk_plus_sparse.multiply(end_gain)

Здесь Vjk_plus_sparse, Vjk_minus_sparse и Vj_mat - это разреженные матрицы CSR, Vj - это массив numpy, а Uj - это матрица numpy, где каждая строкапредставляет Uj(t).end_gain представляет собой массив, который представляет собой статический массив numpy для гашения вибраций.

ВОПРОС: Одна итерация занимает около 3 мс для size = 4250.Самыми значительными шагами являются последние 2 строки.Они вместе занимают около 2,5 мс.В идеале мне нужно, чтобы он работал в 0.1 ms, что было бы более чем в 10 раз быстрее.Это максимальная степень векторизации, возможная для проблемы, и я не могу распараллелить, поскольку я иду вовремя, по крайней мере физически это не будет точным.

ПОПЫТКИ: я попытался поиграться с разреженными структурами данных,и нашел лучшую производительность со всеми из них CSR (Compressed Sparse Row) со значениями, как указано выше.Я также попытался заменить метод multiply() матричным умножением, повторив Vj, но это ухудшило время, поскольку результирующая операция была бы разреженной * плотной операцией.

Как я могу ускорить это в самом Python?Я также открыт для попытки использования c ++, хотя миграция сейчас была бы серьезной болью.Кроме того, поскольку scipy по существу основано на c, даст ли это даже такое ускорение?

Добавлен полный исполняемый пример

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches
import math
from mpl_toolkits import mplot3d
import numpy as np
import scipy.sparse as sp
import scipy.fftpack as spf
import matplotlib.animation as animation
import time

sqrt_3 = 1.73205080757


class Pt:
    def __init__(self,x_0,y_0):
        self.x_0 = x_0
        self.y_0 = y_0
        self.id = -1
        self.neighbours = []
        self.distance = (x_0**2 + y_0**2)**0.5

class Circle:
    def __init__(self,radius,center):
        self.radius = radius
        self.center = center
        self.nodes = []

    def construct_mesh(self, unit):
        queue = [self.center]
        self.center.distance = 0
        curr_id = 0
        delta = [(1.,0.), (1./2, (3**0.5)/2),(-1./2, (3**0.5)/2),(-1.,0.), (-1./2,-(3**0.5)/2), (1./2,- (3**0.5)/2)]
        node_dict = {}
        node_dict[(self.center.x_0,self.center.y_0)] = curr_id
        self.nodes.append(self.center)
        curr_id+=1
        while len(queue)!=0:
            curr_pt = queue[0]
            queue.pop(0)
            # self.nodes.append(curr_pt)
            # curr_id+=1
            for i in delta:

                temp_pt = Pt(curr_pt.x_0 + 2*unit*i[0], curr_pt.y_0 + 2*unit*i[1])
                temp_pt.id = curr_id
                temp_pt.distance = (temp_pt.x_0 ** 2 + temp_pt.y_0 ** 2)**0.5          
                # curr_id+=1
                if (round(temp_pt.x_0,5), round(temp_pt.y_0,5)) not in node_dict and temp_pt.distance <= self.radius:
                    # print(temp_pt.x_0, temp_pt.y_0)
                    self.nodes.append(temp_pt)
                    node_dict[(round(temp_pt.x_0,5), round(temp_pt.y_0,5))] = curr_id
                    curr_id+=1
                    queue.append(temp_pt)
                    curr_pt.neighbours.append(temp_pt.id)

                elif temp_pt.distance <= self.radius:
                    curr_pt.neighbours.append(node_dict[round(temp_pt.x_0,5), round(temp_pt.y_0,5)])

        # print(node_dict)

    def plot_neighbours(self, pt):
        x = []
        y = []
        x.append(pt.x_0)
        y.append(pt.y_0)
        for i in (pt.neighbours):
            x.append(self.nodes[i].x_0)
            y.append(self.nodes[i].y_0)
        plt.scatter(x,y)
        plt.axis('scaled')

    def boundary_node_ids(self):
        boundary_nodes = []
        for j in range(len(self.nodes)):
            if(len(self.nodes[j].neighbours) < 6):
                boundary_nodes.append(j)
        return boundary_nodes

    def add_rim(self, boundary_node_ids, unit):
        c = self.center
        rim_ids = []
        N = len(self.nodes)
        for i in range(len(boundary_node_ids)):
            d = self.nodes[boundary_node_ids[i]].distance
            xp = self.nodes[boundary_node_ids[i]].x_0
            yp = self.nodes[boundary_node_ids[i]].y_0
            xnew = xp + xp*unit/d
            ynew = yp + yp*unit/d
            new_point = Pt(xnew, ynew)
            new_point.id = N + i
            rim_ids.append(N+i)
            self.nodes.append(new_point)
            self.nodes[boundary_node_ids[i]].neighbours.append(new_point.id)
            self.nodes[N+i].neighbours.append(boundary_node_ids[i])
        return rim_ids

def find_nearest_point(mesh, pt):
    distances_from_center = np.zeros(len(mesh.nodes))
    for i in xrange(len(mesh.nodes)):
        distances_from_center[i] = mesh.nodes[i].distance
    target_distance = pt.distance
    closest_point_id = np.argmin(np.abs(distances_from_center-target_distance))
    return closest_point_id

def init_impulse(mesh, impulse,  Vj, poi, roi):
    data = []   
    for i in range(len(Vj)):
        r = ((mesh.nodes[i].x_0 - mesh.nodes[poi].x_0)**2 + (mesh.nodes[i].y_0 - mesh.nodes[poi].y_0)**2)**0.5
        Vj[i] = max(0, impulse*(1. - (r/roi)))
        if i in Nr:
            Vj[i] = 0.
        for k in mesh.nodes[i].neighbours:
            data.append(np.asscalar(Vj[i])/2.)

    return Vj, data


r = 0.1016                                 #Radius of drum head
# rho = 2500                          #Density of drum head
thickness = 0.001                       #Thickness of membrane
# tension = 1500                     #Tension in membrane in N
param = 0.9
c = (param/thickness)**(0.5)    #Speed of wave in string
duration = 0.25
fs = 4000
delta = c/fs

center = Pt(0,0)
point_of_impact = Pt(r/2., 0)
center.id = 0
mesh = Circle(r,center)
mesh.construct_mesh(delta)
N = len(mesh.nodes)

Nb = []
for j in range(N):
    if len(mesh.nodes[j].neighbours) < 6:
        Nb.append(j)

Nr = mesh.add_rim(Nb, delta)

N = len(mesh.nodes)
print(N)

row_ind = []
col_ind = []

for j in range(N):
    for k in mesh.nodes[j].neighbours:
        row_ind.append(j)
        col_ind.append(k)

data = np.ones(len(col_ind))

adj_mat_sparse = sp.csr_matrix((data, (row_ind, col_ind)), shape = (N,N))

Vjk_plus = sp.csr_matrix([N, N])
Vj = np.zeros([N,1])
Uj = np.zeros([int(duration*fs), N])
Vj_mat = sp.csc_matrix([N,N])

closest_point_id = find_nearest_point(mesh, point_of_impact)
Vj, Vjk_data = init_impulse(mesh, -10.0, Vj, closest_point_id, r/10.)
Vjk_minus_sparse = sp.csr_matrix((Vjk_data, (row_ind, col_ind)), shape = (N,N))
constant = (1./3)*np.ones([N,1])


Vjk_plus = Vjk_minus_sparse.transpose()
np.put(Vj, Nr, 0.0)
Uj[1] = Uj[0] + np.transpose(Vj)/fs
Vj_mat = adj_mat_sparse.multiply(Vj)
Vjk_minus_sparse = Vj_mat - Vjk_plus

end_gain = np.ones([N,1])
end_gain[Nr] = 1.0          

for t in range(2,int(duration*fs)): 
    Vjk_plus = Vjk_minus_sparse.transpose()
    Vj = Vjk_plus.dot(constant)
    np.put(Vj, Nr, 0.0)
    Uj[t] = Uj[t-1] + np.transpose(Vj)/fs
    Vj_mat = adj_mat_sparse.multiply(Vj)
    Vjk_minus_sparse = Vj_mat-Vjk_plus.multiply(end_gain)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...