Этот вопрос имеет длинную историю:
У меня есть сложная функция, градиент которой я хочу вычислить, и это слишком быстро. Я решил использовать autograd
, который работал очень хорошо. Однако мне нужно было ускорить его, и поэтому я решил использовать функцию autograd
в jax
, которая может использовать ускорение GPU.
Однако jax
, возвращая тот же ответ, что и autograd
, не смог ускорить операцию и даже иногда приводил к сбою ядра. Через форум jax
я узнал, что проблема в том, что моя функция не была векторизована, то есть она использовала циклы for
вместо векторных операций для перебора массивов (причина этого в том, что я изначально написал функцию дляcython
, что требуется для циклов). Итак, я векторизовал свою функцию (и ее вспомогательные функции-обертки) и убедился, что не допустил ошибок в процессе, убедившись, что функция вернула одинаковые значения.
Однако с этой векторизованной функцией оба autograd
и jax
возвращают нулевой градиент. Я предполагаю, что где-то есть небольшая ошибка или несоответствие, возможно, в функциях-оболочках, но я не могу понять, что именно.
Я опубликовал полный рабочий пример в Google colab с версиями моей функции как векторизованной, так и не векторизованной: https://colab.research.google.com/drive/1VCPiRTLfxOflooDtTlqHI4VzYQkXzaUn
, а также опубликовал проблему на форумах jax
, но безуспешно: https://github.com/google/jax/issues/1407
Может кто-нибудь помочь, пожалуйста?