python: Autograd возвращает нулевой градиент после функции векторизации - PullRequest
0 голосов
/ 18 октября 2019

Этот вопрос имеет длинную историю:

У меня есть сложная функция, градиент которой я хочу вычислить, и это слишком быстро. Я решил использовать autograd, который работал очень хорошо. Однако мне нужно было ускорить его, и поэтому я решил использовать функцию autograd в jax, которая может использовать ускорение GPU.

Однако jax, возвращая тот же ответ, что и autograd, не смог ускорить операцию и даже иногда приводил к сбою ядра. Через форум jax я узнал, что проблема в том, что моя функция не была векторизована, то есть она использовала циклы for вместо векторных операций для перебора массивов (причина этого в том, что я изначально написал функцию дляcython, что требуется для циклов). Итак, я векторизовал свою функцию (и ее вспомогательные функции-обертки) и убедился, что не допустил ошибок в процессе, убедившись, что функция вернула одинаковые значения.

Однако с этой векторизованной функцией оба autograd и jax возвращают нулевой градиент. Я предполагаю, что где-то есть небольшая ошибка или несоответствие, возможно, в функциях-оболочках, но я не могу понять, что именно.

Я опубликовал полный рабочий пример в Google colab с версиями моей функции как векторизованной, так и не векторизованной: https://colab.research.google.com/drive/1VCPiRTLfxOflooDtTlqHI4VzYQkXzaUn

, а также опубликовал проблему на форумах jax, но безуспешно: https://github.com/google/jax/issues/1407

Может кто-нибудь помочь, пожалуйста?

...