Pytorch наиболее эффективный расчет Якобиана / Гессена - PullRequest
1 голос
/ 06 июня 2019

Я ищу наиболее эффективный способ получить якобиан функции через Pytorch и до сих пор нашел следующие решения:

def func(X):
    return torch.stack((
                     X.pow(2).sum(1),
                     X.pow(3).sum(1),
                     X.pow(4).sum(1)
                      ),1)  

X = Variable(torch.ones(1,int(1e5))*2.00094, requires_grad=True).cuda()

                                # Solution 1:
t = time()
Y = func(X)
J = torch.zeros(3, int(1e5))
for i in range(3):
    J[i] = grad(Y[0][i], X, create_graph=True, retain_graph=True, allow_unused=True)[0]
print(time()-t)
Output: 0.002 s

                                # Solution 2:
def Jacobian(f,X):
    X_batch = Variable(X.repeat(3,1), requires_grad=True)
    f(X_batch).backward(torch.eye(3).cuda(),  retain_graph=True)
    return X_batch.grad

t = time()
J2 = Jacobian(func,X)
print(time()-t)
Output: 0.001 s

Поскольку, как представляется, нет большой разницы междуиспользуя цикл в первом решении, а не во втором, я хотел спросить, может ли еще быть более быстрый способ вычисления якобиана в pytorch.эффективный способ вычисления гессиана.

Наконец, кто-нибудь знает, можно ли что-то подобное сделать проще или эффективнее в TensorFlow?

1 Ответ

0 голосов
/ 15 июня 2019

У меня была похожая проблема, которую я решил, определив якобиан вручную (вычисляя производные вручную). Для моей проблемы это было выполнимо, но я могу себе представить, что это не всегда так. Время вычислений ускоряет некоторые факторы на моей машине по сравнению со вторым решением. Однако я не могу использовать вашу функцию grad (...) в первом решении.

# Solution 2
def Jacobian(f,X):
    X_batch = Variable(X.repeat(3,1), requires_grad=True)
    f(X_batch).backward(torch.eye(3).cuda(),  retain_graph=True)
    return X_batch.grad

%timeit Jacobian(func,X)
11.7 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# Solution 3
def J_func(X):
    return torch.stack(( 
                 2*X,
                 3*X.pow(2),
                 4*X.pow(3)
                  ),1)

%timeit J_func(X)
539 µs ± 24.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
...