Python: параллельное матричное умножение в нейронной сети без np.dot или np.matmul - PullRequest
0 голосов
/ 03 января 2019

Я хочу создать код, который может вычислять умножение матриц в нейронной сети без использования тензорного потока или np.dot или np.matmul.

Следующий фрагмент кода, который меня интересует:

class Affine:
def __init__(self, W, b):
    self.W = W
    self.b = b
    self.x = None
    self.original_x_shape = None
    self.dW = None
    self.db = None

def forward(self, x):
    self.original_x_shape = x.shape
    x = x.reshape(x.shape[0], -1)
    self.x = x
    out = np.dot(self.x, self.W) + self.b
    return out

Код является частью форвардного расчета нейронной сети (X * W + b). И это хорошо работает.

Я хочу изменить строку out = np.dot(self.x, self.W) + self.b. Он должен работать точно так же без np.dot или np.matmul.

Вот мой код:

class Affine2:
def __init__(self, W, b):
    self.W = W
    self.b = b
    self.x = None
    self.original_x_shape = None
    self.dW = None
    self.db = None

def forward(self, x):
    self.original_x_shape = x.shape
    x = x.reshape(x.shape[0], -1)
    self.x = x

    rows_A = len(self.x)
    cols_A = len(self.x[0])
    rows_B = len(self.W)
    cols_B = len(self.W[0])

    if cols_A != rows_B:
        print("Cannot multiply the two matrices. Incorrect dimensions.")
        return

    # Create the result matrix
    start_time = time.time()
    out = np.zeros((rows_A, cols_B))

    def matmult(i):
        time.sleep(1)
    # for i in range(rows_A):
        for j in range(cols_B):
            for k in range(cols_A):
                out[i][j] += self.x[i][k] * self.W[k][j]

    if __name__ == '__main__':
        pool = Pool(process_num)
        start_time = int(time.time())

        pool.map(matmult, range(0, rows_A))
        print("Seconds: %s" % (time.time()-start_time))

    return out

Модифицированная часть - это просто параллельное матричное умножение. Однако произошла следующая ошибка: AttributeError: Can't pickle local object 'Affine2.forward.<locals>.matmult'

Как мне решить проблему?

...