Как найти исходный код встроенной функции в pytorch - PullRequest
0 голосов
/ 02 октября 2019

Я пытаюсь провести исследование по нормализации партии, и мне пришлось внести некоторые изменения в код BN pytorch. Я копаюсь в коде pytorch и застреваю с torch.nn.functional.batch_norm, который ссылается на torch.batch_norm.

Проблема в том, что torch.batch_norm не может быть далее найден в библиотеке факелов. Есть ли способ найти исходный код этой встроенной функции и повторно реализовать его? Спасибо!

1 Ответ

3 голосов
/ 02 октября 2019

Он есть, но не определен в Python. Они определены в C ++ в каталогах aten/.

Для ЦП реализация (одна из них зависит от того, является ли вход непрерывным) находится здесь: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126

Для CUDA реализация здесь: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...