Как векторизовать следующий код Python - PullRequest
0 голосов
/ 23 ноября 2018

Я пытаюсь получить матрицу, в которой каждый элемент рассчитывается следующим образом:

X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()

Очень медленно вычислять M по элементам, есть ли какие-либо предложения о том, как использовать матрицуумножение для замены циклов?

Спасибо.

1 Ответ

0 голосов
/ 23 ноября 2018

Если вы хотите sum() сверх темного, вы можете «поднять» свою 2D задачу на 3D и суммировать там:

M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)

Как это работает :

X[:, None, :] и X_[:, None, :] являются 3D размером (batch_size, 1, dim), а Y[None, ...] имеет размер (1, batch_size, dim).

При умножении X_[:, None, :] * Y[None, ...] pytorch передает измерения размера 1 на соответствующее измерение, чтобы получить результат размера (batch_size, batch_size, dim).
Наконец, вы sum() только надпоследнее измерение (dim=2) для получения M размера (batch_size, batch_size).

Хитрость здесь заключается в использовании вещания .

...