PyTorch - это высокоуровневая программная библиотека с множеством оболочек python для высоко оптимизированного скомпилированного кода. Функция или оператор поддерживает пакетные данные или нет. Нет другого пути, кроме как написать собственный код C / C ++ / CUDA и вызвать его с помощью python.
К счастью, большинство функций поддерживают пакетную обработку (включая torch.svd()
, как указано jodag ) и можно предположить, что разработчики (или компилятор) обратили внимание на параллелизм данных при реализации. Я рекомендую вам складывать свои тензоры везде, где вы можете. Обычно это приводит к значительному ускорению.
Обратите внимание, что размер партии всегда является первым размером тензора. PyTorch поддерживает вещание для обычных операторов, таких как +, -, *, /
, как описано здесь . Из-за возможных двусмысленностей вам иногда необходимо соответствующим образом изменить свои данные, чтобы прояснить, что вы хотите. Например, если вы хотите добавить пакет скаляров в пакет векторов, вам нужно сделать что-то вроде:
a = torch.zeros(2, 2)
b = torch.arange(2)
a + b.view(2, 1) # or b.reshape(2, 1)
# tensor([[0., 0.],
[1., 1.]])