Восстановление правильности const для прямого прохода NN - PullRequest
0 голосов
/ 18 апреля 2019

Я пытаюсь реализовать простую нейронную сеть, используя pytorch / libtorch.Следующий пример адаптирован из руководства по интерфейсу libtorch cpp .

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    DeepQImpl(size_t N)
        : linear1(2,5),
          linear2(5,3) {}
    torch::Tensor forward(torch::Tensor x) const {
        x = torch::tanh(linear1(x));
        x = linear2(x);
        return x;
    }
    torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);

Обратите внимание, что функция forward объявлена ​​const.Код, который я пишу, требует, чтобы оценка NN была константной функцией, что мне кажется разумным.Этот код не компилируется, хотя.Компилятор выдает

ошибка: нет совпадения для вызова '(const torch :: nn :: Linear) (at :: Tensor &)'
x = linear1 (x);

Я нашел способ обойти это, определив слои как mutable:

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    /* all the code */
    mutable torch::nn:Linear linear1, linear2;
};

Так что мой вопрос

  1. Почемунанесение слоя на тензор не const
  2. Использует ли mutable способ исправить это и безопасно ли это?

Моя интуиция заключается в том, что в прямом проходе,слои собираются в структуру, которая может использоваться для обратного распространения, что требует некоторой операции записи.Если это правда, возникает вопрос, как собрать слои на первом (не const) шаге, а затем оценить структуру на втором (const) шаге.

...