Я использую vmap для векторизации частей моего кода. Вот минимальный пример до векторизации:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
с vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
Но после векторизации это может пойти не так:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
Чтоздесь происходит то, что samples[0]
имеет форму (2,)
. Векторизованный вызов функции разбивает свой входной аргумент вдоль первой оси и, следовательно, получает 2 массива формы (1,)
. Из-за широковещательной передачи с a
результирующий вывод снова имеет форму (2,)
и укладывается в массив (2,2)
.
Это кажется мне опасным. Код выглядит нормально, и полученный результат может быть легко использован другими правилами вещания, которые скрывают его неправильную форму.
Можно ли придать правильную форму?