BLUF: используя полную функциональность Numpy, плюс еще один аккуратный модуль, вы можете получить код Python более чем в 100 раз быстрее, чем этот необработанный код цикла for. Однако, используя ответ @ max9111, вы можете получить еще быстрее с гораздо более чистым кодом и меньшим количеством работы.
Полученный код не похож на оригинальный, поэтому я буду выполнять оптимизацию по одному шагу за раз, чтобы процесс и окончательный код имели смысл. По сути, мы собираемся использовать большое количество трансляций , чтобы заставить Numpy выполнять зацикливание (что всегда быстрее, чем зацикливание в Python). Результат вычисляет полный квадрат результатов, что означает, что мы обязательно дублируем некоторую работу, поскольку результат симметричен, но проще и, честно говоря, возможно быстрее выполнить эту работу высокопроизводительными способами, чем иметь if
при самый глубокий уровень зацикливания, чтобы избежать вычислений. Этого можно избежать в Фортране, но, вероятно, не в Python. Если вы хотите, чтобы результат был идентичен предоставленному источнику, нам нужно взять верхний треугольник результата моего кода ниже (что я делаю в примере кода ниже ... не стесняйтесь удалить вызов triu
в реальном производстве, это не обязательно).
Во-первых, мы заметим несколько вещей. Основное уравнение имеет знаменатель, который выполняет np.sqrt
, но содержание этого вычисления не изменяется на любой итерации цикла, поэтому мы вычислим его один раз и повторно используем результат. Это оказывается незначительным, но мы все равно это сделаем. Далее, основная функция двух внутренних циклов заключается в выполнении eigv[k1][j1] - eigv[k1][i1]
, что довольно легко векторизовать. Если eigv
является матрицей, то eigv[k1] - eigv[k1].T
создает матрицу, где result[i1, j1] = eigv[k1][j1] - eigv[k1][i1]
. Это позволяет нам полностью удалить две самые внутренние петли:
def mine_Delta_Gaussf(Nw, N_bd, N_kp, hw, width, eigv):
Delta_Gauss = np.zeros((Nw, N_kp, N_bd, N_bd), dtype=float)
denom = np.sqrt(2.0 * np.pi) * width
eigv = np.matrix(eigv)
for w1 in range(Nw):
for k1 in range(N_kp):
this_eigv = (eigv[k1] - eigv[k1].T - hw[w1])
v = np.power(this_eigv / width, 2)
Delta_Gauss[w1, k1, :, :] = np.exp(-0.5 * v) / denom
# Take the upper triangle to make the result exactly equal to the original code
return np.triu(Delta_Gauss)
Что ж, теперь, когда мы находимся в эфире, кажется, что оставшиеся две петли можно удалить таким же образом. Как это бывает, это легко! Единственное, для чего нам нужно k1
- это получить строку из eigv
, которую мы пытаемся попарно вычесть ... так почему бы не сделать это для всех строк одновременно? В настоящее время мы в основном вычитаем матрицы форм (1, B) - (B, 1)
для каждой из N
строк в eigv
(где B
равно N_bd
). Мы можем злоупотреблять трансляцией, чтобы сделать это для всех строк eigv
одновременно, вычитая матрицы форм (N, 1, B) - (N, B, 1)
(где N
равно N_kp
):
def mine_Delta_Gaussf(Nw, N_bd, N_kp, hw, width, eigv):
Delta_Gauss = np.zeros((Nw, N_kp, N_bd, N_bd), dtype=float)
denom = np.sqrt(2.0 * np.pi) * width
for w1 in range(Nw):
this_eigv = np.expand_dims(eigv, 1) - np.expand_dims(eigv, 2) - hw[w1]
v = np.power(this_eigv / width, 2)
Delta_Gauss[w1, :, :, :] = np.exp(-0.5 * v) / denom
return np.triu(Delta_Gauss)
Следующий шаг должен быть понятен сейчас. Мы используем w1
только для индексации hw
, так что давайте сделаем еще несколько трансляций, чтобы numpy
зациклился. В настоящее время мы вычитаем скалярное значение из матрицы формы (N, B, B)
, поэтому, чтобы получить результирующую матрицу для каждого из W
значений в hw
, нам необходимо выполнить вычитание для матриц форм (1, N, B, B) - (W, 1, 1, 1)
и numpy
будет транслировать все, чтобы получить матрицу формы (W, N, B, B)
:
def Delta_Gaussf(hw, width, eigv):
eigv_sub = np.expand_dims(eigv, 1) - np.expand_dims(eigv, 2)
w_sub = np.expand_dims(eigv_sub, 0) - np.reshape(hw, (0, 1, 1, 1))
v = np.power(w_sub / width, 2)
denom = np.sqrt(2.0 * np.pi) * width
Delta_Gauss = np.exp(-0.5 * v) / denom
return np.triu(Delta_Gauss)
На моем примере данных этот код работает в ~ 100 раз быстрее (от ~ 900 мс до ~ 10 мс). Ваш пробег может отличаться.
Но подождите! Есть еще кое-что! Поскольку весь наш код числовой / numpy / python, мы можем использовать другой удобный модуль с именем numba
, чтобы скомпилировать эту функцию в эквивалентный с более высокой производительностью. В сущности, это в основном чтение того, какие функции мы вызываем, и преобразование функции в C-типы и C-вызовы для устранения накладных расходов при вызове функций Python. Это делает больше, чем это, но это дает представление о том, где мы собираемся получить выгоду. Получить это преимущество тривиально в этом случае:
import numba
@numba.jit
def Delta_Gaussf(hw, width, eigv):
eigv_sub = np.expand_dims(eigv, 1) - np.expand_dims(eigv, 2)
w_sub = np.expand_dims(eigv_sub, 0) - np.reshape(hw, (0, 1, 1, 1))
v = np.power(w_sub / width, 2)
denom = np.sqrt(2.0 * np.pi) * width
Delta_Gauss = np.exp(-0.5 * v) / denom
return np.triu(Delta_Gauss)
Полученная функция сократилась до ~ 7 мс по моим образцам данных, по сравнению с ~ 10 мс, просто добавив этот декоратор. Довольно хорошо, без усилий.
EDIT: @ max9111 дал лучший ответ, который указывает, что numba
работает намного лучше с синтаксисом цикла, чем с numpy
широковещательным кодом. Почти без работы, кроме удаления внутреннего оператора if
, он показывает, что numba.jit
можно сделать, чтобы получить почти оригинальный код еще быстрее. Результат намного чище, потому что у вас все еще есть только одно внутреннее уравнение, которое показывает, каково каждое значение, и вам не нужно следовать волшебному вещанию, используемому выше. Я настоятельно рекомендую использовать его ответ.
Заключение
Для моих данных примера (Nw = 20, N_bd = 20, N_kp = 20) мои окончательные значения времени выполнения следующие (я включил тайминги на том же компьютере для решения @ max9111, сначала без параллельного выполнения, а затем с ним на моей 2-ядерной виртуальной машине):
Original code: ~900 ms
Fortran estimate: ~90 ms (based on OP saying it was ~10x faster)
Final numpy code: ~10 ms
Final code with numba.jit: ~7 ms
max9111's solution (serial): ~4ms
max9111 (parallel 2-core): ~3ms
Overall vectorized speedup: ~130x
max9111's numba speedup: ~300x (potentially more with more cores)
Я не знаю, насколько точен ваш код на Фортране, но похоже, что правильное использование numpy
позволяет легко превзойти его на порядок, а решение * max6911 от @ max9111 дает вам потенциально другой порядок величины.