Где torch.cholesky и как torch относится к его методам? - PullRequest
0 голосов
/ 30 января 2020

Я занимаюсь исследованием разложения Холецкого, которое требует некоторого понимания того, как работает torch.cholesky. Через некоторое время поиска и поиска в ATen я застрял в TensorMethods.h, который интересно имеет следующий код:

inline Tensor Tensor::cholesky(bool upper) const {
#ifdef USE_STATIC_DISPATCH
    return TypeDefault::cholesky(const_cast<Tensor&>(*this), upper);
#else
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky", ""}).value();
    return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &, bool>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this), upper);
#endif
}

Это подняло вопрос о том, как torch находит свои методы. Спасибо!

1 Ответ

2 голосов
/ 30 января 2020

Взгляните на aten / src / ATen / native / README.md , который описывает, как функции регистрируются в API.

ATen "родными" функциями являются современный механизм добавления операторов и функций в ATen (они являются «нативными» в отличие от устаревших функций, которые связаны через метаданные TH / TH C cwrap). Собственные функции объявлены в native_functions.yaml и имеют реализации, определенные в одном из cpp файлов в этом каталоге.

Если мы посмотрим на aten / src / ATen / native / native_functions.yaml и для поиска cholesky мы находим

- func: cholesky(Tensor self, bool upper=False) -> Tensor
  use_c10_dispatcher: full
  variants: method, function

Чтобы найти точку входа, вам просто нужно искать файлы. cpp в aten / src / ATen / native и найдите функцию с именем cholesky. В настоящее время его можно найти по адресу BatchLinearAlgebra. cpp: 550

Tensor cholesky(const Tensor &self, bool upper) {
  if (self.size(-1) == 0) {
    return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  squareCheckInputs(self);

  auto raw_cholesky_output = at::_cholesky_helper(self, upper);
  if (upper) {
    return raw_cholesky_output.triu_();
  } else {
    return raw_cholesky_output.tril_();
  }
}

С этого момента достаточно понять код C ++, чтобы понять, что происходит.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...