Вывод структуры сети Pytorch - PullRequest
       73

Вывод структуры сети Pytorch

2 голосов
/ 06 октября 2019

Для моего случая использования мне нужно иметь возможность взять модуль pytorch и интерпретировать последовательность слоев в модуле, чтобы я мог создать «соединение» между слоями в некотором формате файла. Теперь предположим, что у меня есть простой модуль, как показано ниже:

class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)
    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


if __name__ == "__main__":
    net = mymodel(5)

    for mod in net.modules():
        print(mod) 

Здесь вывод дает:

mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

, так как вы можете видеть информацию об операции плюс равно или операция плюс не захваченатак как это не nnmodule в функции forward. Моя цель - создать графическое соединение из объекта модуля pytorch и сказать что-то вроде этого в json:

layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}

Имена входных тензоров произвольны, но они отражают суть графа и соединений. между слоями.

У меня вопрос, есть ли способ извлечь информацию из объекта pytorch? Я думал об использовании .modules (), но потом понял, что рукописные операции не фиксируются таким образом как модуль. Я думаю, что если все это nn.module, то .modules () может дать мне расположение сетевого уровня. Нужна помощь здесь. Я хочу иметь возможность знать связи между тензорами, чтобы создать формат, как указано выше.

1 Ответ

2 голосов
/ 06 октября 2019

Информация, которую вы ищете, хранится не в nn.Module, а в атрибуте grad_fn выходного тензора:

model = mymodel(channels)
pred = model(torch.rand((1, channels))
pred.grad_fn  # all the information is in the computation graph of the output tensor

Извлекать эту информацию нетривиально. Возможно, вы захотите взглянуть на пакет torchviz , который рисует хороший график из информации grad_fn.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...