Pytorch: метод быстрого создания тензора оценок - PullRequest
1 голос
/ 12 июля 2020

Я новичок в Pytorch и ищу функцию быстрого получения очков. Это, учитывая набор образцов и распределение, выводит тензор, состоящий из соответствующей оценки для каждой отдельной выборки. Например, рассмотрим следующий код:

norm = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2),torch.eye(2))
samples = norm.sample((1000,))
samples.requires_grad_(True)

Используя samples, я хотел бы создать тензор score размером [1000,2], где i-й компонент score[i] - это градиент log p(samples[i]) , где p - плотность данного распределения. Я придумал следующий метод:

def get_score(samples,distribution):
    log_probs = distribution.log_prob(samples)
    for i in range(log_probs.size()[0]):
        log_probs[i].backward(retain_graph = True)

В результате получается тензор score samples.grad. Проблема в том, что мой метод довольно медленный для больших выборок (например, для выборки размером [50000,2] это занимает около 25-30 секунд на моем процессоре). Это настолько быстро, насколько это возможно?

Единственная альтернатива, которую я могу придумать, - это жестко закодировать функцию оценки для каждого распределения, которое я буду использовать, это не похоже на хорошее решение!

По результатам экспериментов, для 50000 образцов следующее примерно на 50% быстрее:

for i in range(50000):
    sample = norm.sample((1,))
    sample.requires_grad_(True)
    log_prob = norm.log_prob(a)
    log_prob.backward()

Это указывает на то, что должен быть лучший способ!

1 Ответ

1 голос
/ 12 июля 2020

Я предполагаю, что log_probs хранится как тензор pytorch. Вы можете воспользоваться линейностью дифференцирования для вычисления производной сразу для всех выборок: log_probs.sum().backward(retain_graph = True)

По крайней мере, с ускорением графического процессора это будет намного быстрее.

Если log_probs не тензор, но список скаляров (представленных как тензоры pytorch ранга 0), вы можете сначала использовать log_probs = torch.stack(log_probs).

...