Я пытаюсь реализовать вычитание baseline als в pytorch, чтобы я мог запустить его на своем GPU, но у меня возникают проблемы, потому что pytorch.gesv дает результат, отличный от scipy.linalg.spsolve. Вот мой код для scipy:
def baseline_als(y, lam, p, niter=10):
L = len(y)
D = sparse.diags([1,-2,1],[0,-1,-2], shape=(L,L-2))
w = np.ones(L)
for i in range(niter):
W = sparse.spdiags(w, 0, L, L)
Z = W + lam * D.dot(D.transpose())
z = spsolve(Z, w*y)
w = p * (y > z) + (1-p) * (y < z)
return z
и вот мой код для pytorch
def baseline_als_pytorch(y, lam, p, niter=10):
diag = torch.tensor(np.repeat(1, L))
diag = torch.diag(diag, 0)
diag_minus_one = torch.tensor(np.repeat(-2, L - 1))
diag_minus_one = torch.diag(diag_minus_one, -1)
diag_minus_two = torch.tensor(np.repeat(1, L - 2))
diag_minus_two = torch.diag(diag_minus_two, -2)
D = diag + diag_minus_one + diag_minus_two
D = D[:, :L - 2].double()
w = torch.tensor(np.repeat(1, L)).double()
for i in range(10):
W = diag.double()
Z = W + lam * torch.mm(D, D.permute(1, 0))
z = torch.gesv(w * y, Z)
z = z[0].squeeze()
w = p * (y > z).double() + (1 - p) * (y < z).double()
return z
Извините, что код pytorch выглядит так плохо, я только начинаю.
Я подтвердил, что Z, w и y одинаково входят в scipy и pytorch и что z различается между ними сразу после того, как я пытаюсь решить систему уравнений.
Спасибо за комментарий, вот пример:
Я использую 100000 для lam и 0.001 для p.
Использование фиктивного ввода: y = (5,5,5,5,5,10,10,5,5,5,10,10,5,5,5,5,5,5,5 ),
я (3,68010263, 4,90344214, 6,12679489, 7,35022406, 8,57384278, 9,79774074, +11,02197199, 12,2465927, +13,47164891, 14.69711435,15.92287813, +17,14873257, +18,37456982, +19,60038184, 20.82626043,22.05215157, +23,27805103, +24,50400438, +25,73010693, +26,95625922) от SciPy и
(6,4938312, 6,46912395, 6,44440175, 6,41963499, 6.39477958,6.36977727, 6,34455582, 6,31907933, 6,29334844, 6,26735058, 6,24106029, 6,21443939, 6,18748732, 6,16024137, 6.13277694,6.10515785, 6,07743658, 6,04965455, 6,02184242, 5,99402035) от pytorch.
Это всего за одну итерацию цикла. Сципи прав, а питор нет.
Интересно, что если я использую краткий фиктивный ввод (5,5,5,5,5,10,10,5,5,5), я получу один и тот же ответ от обоих. Мой реальный вклад - 1011 размерность.