странный синтаксис libtorch (PyTorch C ++) - PullRequest
5 голосов
/ 20 апреля 2020

В официальных примерах PyTorch C ++ на GitHub Здесь вы можете увидеть странное определение класса:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}

Насколько я понимаю, это определяет класс CustomDataset, который " наследуется от "или" расширяется "torch::data::datasets::Dataset<CustomDataset>. Это странно для меня, так как класс, который мы создаем, наследуется от другого класса, параметризованного классом, который мы создаем ... Как это вообще работает? Что это значит? Мне кажется, что класс Integer унаследован от vector<Integer>, что выглядит абсурдно.

1 Ответ

8 голосов
/ 20 апреля 2020

Это странно повторяющийся шаблон , или сокращенно CRTP. Основным преимуществом этого метода является то, что он включает так называемый stati c полиморфизм , что означает, что функции в torch::data::datasets::Dataset могут вызывать функции CustomDataset без необходимости превращения этих функций в виртуальные (и таким образом, иметь дело с беспорядком времени выполнения виртуальной отправки метода и так далее). Вы также можете выполнить метапрограммирование времени компиляции, например время компиляции enable_if с, в зависимости от свойств пользовательского типа набора данных.

В случае PyTorch, BaseDataset ( Суперкласс Dataset) интенсивно использует эту технику для поддержки таких операций, как сопоставление и фильтрация:

  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    return datasets::map(static_cast<Self&>(*this), std::move(transform));
  }

Обратите внимание на приведение stati c this к производному типу ( законно до тех пор, пока CRTP применяется должным образом); datasets::map создает объект MapDataset, который также параметризован типом набора данных, позволяя реализации MapDataset статически вызывать методы, такие как get_batch (или если в них возникает ошибка время компиляции не существует).

Кроме того, поскольку MapDataset получает пользовательский тип набора данных в качестве параметра типа, возможно метапрограммирование во время компиляции:

  /// The implementation of `get_batch()` for the stateless case, which simply
  /// applies the transform to the output of `get_batch()` from the dataset.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// The implementation of `get_batch()` for the stateful case. Here, we follow
  /// the semantics of `Optional.map()` in many functional languages, which
  /// applies a transformation to the optional's content when the optional
  /// contains a value, and returns a new optional (of a different type)  if the
  /// original optional returned by `get_batch()` was empty.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }

Обратите внимание, что условное разрешение зависит от SourceDataset, который у нас есть только потому, что набор данных параметризован с этим шаблоном CRTP.

...