Как добавить параметры из self.some_dictionary_of_modules в self.parameters? - PullRequest
0 голосов
/ 19 марта 2019

Рассмотрим простой пример:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv_0=torch.nn.Conv2d(3,32,3,stride=1,padding=0)
        self.blocks=torch.nn.ModuleList([
            torch.nn.Conv2d(3,32,3,stride=1,padding=0),
            torch.nn.Conv2d(32,64,3,stride=1,padding=0)])

        #the problematic part
        self.dict_block={"key_1": torch.nn.Conv2d(64,128,3,1,0),
                "key_2": torch.nn.Conv2d(56,1024,3,1,0)}

if __name__=="__main__":
    my_module=MyModule()
    print(my_module.parameters)

Вывод, который я здесь получаю (обратите внимание, что параметры из self.dict_block отсутствуют)

<bound method Module.parameters of MyModule(
  (conv_0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (blocks): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  )
)>

Это означает, что если я хочуДля оптимизации параметров self.dict_block мне нужно будет использовать что-то вроде

my_optimiser.add_param_group({"params": params_from_self_dict})

, прежде чем использовать мой оптимизатор.Тем не менее, я думаю, что может быть более прямолинейная альтернатива, которая добавит параметры self.dict_block к параметрам my_module_object.Что-то, что приближается, это nn.Parameter(...), как объяснено здесь , но для этого требуется, чтобы входное значение было тензорным.

1 Ответ

1 голос
/ 19 марта 2019

Нашел ответ.Отправка сообщения на случай, если кто-то столкнется с той же проблемой:

Если посмотреть на <torch_install>/torch/nn/modules/container.py, то, похоже, существует класс torch.nn.ModuleDict, который просто делает это.Таким образом, в примере, который я привел в вопросе, решение будет:

self.dict_block=torch.nn.ModuleDict({"key_1": torch.nn.Conv2d(64,128,3,1,0),
            "key_2": torch.nn.Conv2d(56,1024,3,1,0)})
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...