Давайте упростим задачу, отбросив одно измерение и используя значения, которые легко проверить вручную:
W = np.zeros(3, np.int)
y = np.array([0, 1, 1, 2, 2])
X = np.array([1, 2, 3, 4, 5])
Значения в векторе W
получают дополнительные значения из X
, просматривая y
:
for i, j in enumerate(y):
W[j] += X[i]
W
рассчитывается как [1, 5, 9]
, (проверьте быстро вручную).
Теперь, как этот код может быть векторизован?Мы не можем сделать простой W[y] += X[y]
, так как y
содержит повторяющиеся значения, и разные суммы будут перезаписывать друг друга с индексами 1 и 2.
Что можно сделать, это передать значения вновое измерение len(y)
, а затем суммируйте по этому вновь созданному измерению.
N = W.shape[0]
select = (np.arange(N) == y[:, None]).astype(np.int)
Принимая диапазон индекса W
([0, 1, 2]
) и устанавливая значения, для которых они соответствуют y
, в1 в новом измерении, иначе 0. select
содержит этот массив:
array([[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1]])
Он имеет len(y) == len(X)
строк и len(W)
столбцов и показывает для каждого y / ряда, какой индекс W
Давайте умножим X с этим массивом, mult = select * X[:, None]
:
array([[1, 0, 0],
[0, 2, 0],
[0, 3, 0],
[0, 0, 4],
[0, 0, 5]])
Мы эффективно разложили X в новое измерение и отсортировали его таким образом, чтобы получитьэто в форму W путем суммирования по вновь созданному измерению.Сумма по строкам - это вектор, который мы хотим добавить к W
:
sum_Xy = np.sum(mult, axis=0) # [1, 5, 9]
W += sum_Xy
Вычисление select
и mult
можно комбинировать с np.einsum
:
# `select` has shape (len(y)==len(X), len(W)), or `yw`
# `X` has shape len(X)==len(y), or `y`
# we want something `len(W)`, or `w`, and to reduce the other dimension
sum_Xy = np.einsum("yw,y->w", select, X)
И это все для одномерного примера.Для двумерной задачи, поставленной в вопросе, это точно такой же подход: ввести дополнительное измерение, передать индексы y
, а затем уменьшить дополнительное измерение с помощью einsum
.
Если вы усвоите, каккаждый шаг работает для одномерного примера, я уверен, что вы можете понять, как код делает это в двух измерениях, так как это просто вопрос правильного определения индексов (W строк, X столбцов).