Jax Vmap: обеспечить правильную форму - PullRequest
0 голосов
/ 28 октября 2019

Я использую vmap для векторизации частей моего кода. Вот минимальный пример до векторизации:

dim = 2
def sum(x):
    a = np.ones((dim,))
    return np.dot(x, a)

num_samples = 100
samples = np.ones((num_samples, dim))

sum(samples[0]) # 2

с vmap:

sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2

Но после векторизации это может пойти не так:

sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1

Чтоздесь происходит то, что samples[0] имеет форму (2,). Векторизованный вызов функции разбивает свой входной аргумент вдоль первой оси и, следовательно, получает 2 массива формы (1,). Из-за широковещательной передачи с a результирующий вывод снова имеет форму (2,) и укладывается в массив (2,2).

Это кажется мне опасным. Код выглядит нормально, и полученный результат может быть легко использован другими правилами вещания, которые скрывают его неправильную форму.

Можно ли придать правильную форму?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...