PyTorch использует эффективную реализацию BLAS и многопоточность (openMP, если я не ошибаюсь) для распараллеливания таких операций с несколькими ядрами.Некоторая потеря производительности происходит из-за самого Python - поскольку это интерпретируемый язык, никакой существенной оптимизации, подобной компилятору, не может быть сделано.Вы можете использовать модуль jit
для ускорения кода «обертки» вокруг умножений матриц, но для чего-то большего, чем очень маленькие матрицы, эта стоимость, вероятно, незначительна.
Одно большое улучшениеВы можете получить вручную, но то, что PyTorch не применяет автоматически, это для правильного порядка умножения матрицы.Как вы, вероятно, знаете, в зависимости от форм матрицы умножение ABCD
может иметь различную производительность, вычисляемую как A(B(CD))
, чем вычисление как (AB)(CD)
и т. Д.