Пользовательские подмодули в pytorch / libtorch C ++ - PullRequest
1 голос
/ 30 апреля 2020

Полное раскрытие, я задавал этот же вопрос на форумах PyTorch около нескольких дней go и не получил ответа, так что это технически репост, но я считаю, что это все еще хороший вопрос, потому что я не смог найти ответ где угодно в Интернете. Вот так:

Можете ли вы показать пример использования register_module с пользовательским модулем? Единственные примеры, которые я нашел в Интернете, - это регистрация линейных или сверточных слоев в качестве подмодулей.

Я попытался написать свой собственный модуль и зарегистрировать его в другом модуле, и я не смог заставить его работать. Моя IDE говорит мне no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)

(TreeEmbedding - это имя другой созданной мной структуры, расширяющей torch :: nn :: Module.)

Я что-то упустил? Пример этого будет очень полезен.



Редактировать: дополнительный контекст следует ниже.

У меня есть файл заголовка "model.h" который содержит следующее:

struct TreeEmbedding : torch::nn::Module {
    TreeEmbedding();
    torch::Tensor forward(Graph tree);
};

struct MyModel : torch::nn::Module{
    size_t embeddingSize;
    TreeEmbedding treeEmbedding;

    MyModel(size_t embeddingSize=10);
    torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};

У меня также есть cpp файл "модель. cpp", который содержит следующее:

MyModel::MyModel(size_t embeddingSize) :
    embeddingSize(embeddingSize)
{
    treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}

Эта установка все еще имеет ту же ошибку как указано выше. Код в документации действительно работает (используя встроенные компоненты, такие как линейные слои), но использование пользовательского модуля - нет. После отслеживания факела :: nn :: Linear, это выглядит так, как будто это ModuleHolder (что бы это ни было ...)

Спасибо, Джек

1 Ответ

1 голос
/ 01 мая 2020

Я приму лучший ответ, если кто-нибудь сможет предоставить больше подробностей, но на всякий случай, если кому-то интересно, я подумал, что предоставлю небольшую информацию, которую смог найти:

register_module принимает строку как его первый аргумент и второй аргумент могут быть ModuleHolder (я не знаю, что это такое) или альтернативно shared_ptr для вашего модуля. Итак, вот мой пример:

treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", make_shared<TreeEmbedding>());

Мне показалось, что пока это работает.

...