Необязательные тензоры в расширении PyTorch C ++ - PullRequest
0 голосов
/ 14 февраля 2019

Я пишу расширение C ++ для Pytorch и использую API C ++ для этого.Для моей функции forward мне нужно передать дополнительный тензор.Внутри функции я хочу делать разные вещи в зависимости от того, был ли передан этот необязательный параметр или нет.В общем, мы используем NULL для необязательных аргументов указателя в C ++ и проверяем внутри функции, является ли указатель NULL или нет.Я не знаю, как это сделать для типа at::Tensor API Torch c ++.

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)
{
     if(optional_constraints){
        //do something
     }else{
        //do something else
     }
}

Обратите внимание, что я не могу сделать const at::Tensor optional_constraints = at::ones или что-то еще, потому что этот параметр может принимать любое реальное значение и может иметь различный размер / форму.Я не могу присвоить ему числовое значение в качестве необязательного аргумента.Есть ли NULL эквивалент для этого?

Ответы [ 2 ]

0 голосов
/ 15 февраля 2019

Можно использовать std :: необязательный как std::optional<at::Tensor> optional_constraints = std::nullopt.Он контекстно конвертируется в bool, поэтому вы можете проверить его с помощью if (optional_constraints).Используйте метод .value(), чтобы получить тензор, если вы его передадите, иначе значением по умолчанию будет std::nullopt.

0 голосов
/ 14 февраля 2019

Так как я не могу найти ничего похожего, например.OpenCV noArray() (который в основном используется для передачи оптинальных матриц, таких как маски) в API , я бы предложил вам использовать для этой цели перегруженную функцию

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2)
{
     // optional tensor wasnt passed
}

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints)
{
     // optional tensor passed
}
...