Эффективная реализация факторизации машины с матричными операциями? - PullRequest
0 голосов
/ 02 июля 2018

Ссылка здесь: https://www.csie.ntu.edu.tw/~r01922136/slides/ffm.pdf (слайды 5-6)

Даны следующие матрицы:

X : n * d 
W : d * k

Существует ли эффективный способ вычисления матрицы n x 1, используя только матричные операции (например, numpy, tenorflow), где j-й элемент:

EDIT: Текущая попытка такова, но, очевидно, она не очень экономит место, так как требует хранения матриц размера n*d*d:

n = 1000
d = 256
k = 32

x = np.random.normal(size=[n,d])
w = np.random.normal(size=[d,k])

xxt = np.matmul(x.reshape([n,d,1]),x.reshape([n,1,d]))
wwt = np.matmul(w.reshape([1,d,k]),w.reshape([1,k,d]))
output = xxt*wwt
output = np.sum(output,(1,2))

1 Ответ

0 голосов
/ 05 июля 2018

Избегайте больших временных массивов

Не все типы алгоритмов так просто или очевидно векторизовать. np.sum(xxt*wwt) можно переписать с помощью np.einsum. Это должно быть быстрее, чем ваше решение, но имеет некоторые другие ограничения (например, нет многопоточности).

Я бы предложил использовать такой компилятор, как Numba.

Пример

import numpy as np
import numba as nb
import time

@nb.njit(fastmath=True,parallel=True)
def factorization_nb(w,x):
  n = x.shape[0]
  d = x.shape[1]
  k = w.shape[1]

  output=np.empty(n,dtype=w.dtype)
  wwt=np.dot(w.reshape((d,k)),w.reshape((k,d)))

  for i in nb.prange(n):
    sum=0.
    for j in range(d):
      for jj in range(d):
        sum+=x[i,j]*x[i,jj]*wwt[j,jj]
    output[i]=sum
  return output

def factorization_orig(w,x):
  n = x.shape[0]
  d = x.shape[1]
  k = w.shape[1]

  xxt = np.matmul(x.reshape([n,d,1]),x.reshape([n,1,d]))
  wwt = np.matmul(w.reshape([1,d,k]),w.reshape([1,k,d]))
  output = xxt*wwt
  output = np.sum(output,(1,2))

  return output

Производительность измерений

n = 1000
d = 256
k = 32

x = np.random.normal(size=[n,d])
w = np.random.normal(size=[d,k])

#first call has some compilation overhead
res_1=factorization_nb(w,x)
t1=time.time()
for i in range(100):
  res_1=factorization_nb(w,x)
  #res_2=factorization_orig(w,x)

print(time.time()-t1)

Задержка

factorization_nb: 4.2 ms per iteration
factorization_orig: 460 ms per iteration (110x speedup)
...