Как ускорить нарезку в python, не используя цикл for - PullRequest
2 голосов
/ 08 февраля 2020

Я пытаюсь ускорить следующий python код:

import torch
import numpy as np

A = torch.zeros(11, 16, 64)
B = torch.randn(11, 9, 64)

indices = np.random.randint(0,9,(11,16))

for i in range(len(A)):
    A[i,:,:] = B[i,indices[i],:]

Есть ли хороший способ не использовать для l oop? Таким образом, это очень медленно, особенно когда дело касается больших данных. Индексы - это предопределенная 2-мерная матрица с размером (11,16). Что мне нужно, это назначить элементы B для A в соответствии с порядком индексов. После ускорения результат A должен быть точно таким же, как и мой результат A. Спасибо!

Ответы [ 2 ]

1 голос
/ 08 февраля 2020

Вы можете использовать несколько многомерных индексов, но они должны быть одинакового размера или транслируемыми. Например,

# create a (11, 1) range array that broadcasts with indices which is (11, 16)
indices0 = np.expand_dims(np.arange(indices.shape[0]), 1)
A = B[indices0, indices, :]

Поскольку вещание может сбивать с толку, я попытаюсь объяснить это немного. По сути, вы хотите, чтобы indices0 и indices имели одинаковый размер и представляли собой пары индексов B. Первый индекс будет сохранен в indices0, а второй будет сохранен в indices в соответствующих местах. Вещание неявно повторяет столбцы indices0, чтобы придать ему ту же форму, что и indices, и часто может быть быстрее, чем создание полноразмерного indices0.

. В случае, если это поможет, приведем еще несколько подробных примеров, демонстрирующих почему это работает:

import torch
import numpy as np

B = torch.randn(11, 9, 64)
indices = np.random.randint(0,9,(11,16))

# constructing indices0 more verbosely (and slower) for demonstration purposes
a0, a1 = indices.shape
a2 = B.shape[2]

# construct a complete indices0 the slow way, the same size as indices
indices0 = np.empty((a0, a1), dtype=np.int32)
for i in range(a0):
    for j in range(a1):
        indices0[i,j] = i

# version 1 (nothing complicated happening here but very slow)
A1 = torch.empty(a0, a1, a2, dtype=B.dtype)
for i in range(a0):
    for j in range(a1):
        A1[i,j,:] = B[indices0[i,j], indices[i,j], :]

# version 2 (using advanced indexing without broadcasting)
A2 = B[indices0, indices, :]

# version 3 (with broadcasting)
# remove repeated columns leaving indices0 as (11, 1) the same state as above
indices0 = indices0[:, :1]
# broadcasting implicitly repeats columns of indices0 to match indices
A3 = B[indices0, indices, :]

# version 4 (your method)
A4 = torch.empty(a0, a1, a2, dtype=B.dtype)
for i in range(a0):
    A4[i,:,:] = B[i,indices[i],:]

# compare everything    
error = torch.sum(torch.abs(A1 - A2)).item() + \
        torch.sum(torch.abs(A2 - A3)).item() + \
        torch.sum(torch.abs(A3 - A4)).item()
print('Error:', error)

, который печатает

Error: 0.0

, демонстрируя, что все эти методы эквивалентны.


Кроме того, если вы хотите остаться в Платформа PyTorch и indices были torch.LongTensor вместо numpy.ndarray, тогда вы могли бы использовать

indices0 = torch.arange(indices.shape[0]).unsqueeze(1)
A = B[indices0, indices, :]
0 голосов
/ 08 февраля 2020

Нарезка с использованием numpy достаточно быстра даже для проектов машинного обучения. Если вы хотите, чтобы ваш код работал быстрее в этом случае, вы должны использовать это:

A_length = len(A)
i = 0
while i < A_length:
    A[i,:,:] = B[i,indices[i],:]
    i += 1

A range объект использует __iter__ и __next__ метод для генерации индекса итерации (в большинстве случаев ), даже если он записан в C, это медленнее, чем просто объявить счетчик индекса и добавлять шаг к нему каждый раунд.

Но for l oop более читабелен и прост для вашего кода плюс использование while l oop не сильно увеличит скорость. Я не думаю, что вы должны использовать while l oop для небольшого увеличения скорости.

Но ...

Если вы хотите, чтобы ваш код работал так быстро, как он мог:

  1. Изучите некоторые приемы производительности
  2. Рассмотрите возможность использования s sh и удаленного сервера GPU. Они намного быстрее вашего процессора (если вы используете компьютер с процессором)
  3. Learn C или JS, это скомпилированные языки. C примерно в 200 раз быстрее, чем python без оптимизации, и примерно в 50000 раз с многоядерной обработкой и многопоточностью (от здесь )
...