Есть ли более быстрый способ сдвинуть списки «мимо» друг друга? - PullRequest
0 голосов
/ 23 января 2019

Мой код берет два отдельных списка A и B и сдвигает один «мимо» другого (Изобразите поезд, проходящий мимо припаркованного поезда), взяв квадрат точечного произведения последовательностей в каждую смену.

Я не знаю многих уловок в Python, поэтому, возможно, мое решение довольно неуклюже.

for shift in range (1,len(B)):
   total += np.dot( A[-shift:] , B[:shift] )**2 + np.dot( A[:shift] , B[-shift:] )**2
total += (np.dot(A,B))**2

Он работает как есть, но сейчас я работаю с такими массивными наборами данных, что скорость становится серьезной проблемой.

1 Ответ

0 голосов
/ 23 января 2019

При создании фрагментов в python (например, A[-shift:]) вы создаете копию массива. Например, вы можете увидеть это, выполнив:

A=[1,2,3]
B=A[:1]
B[0]=17
A // => A is still [1,2,3] not [17,2,3]

Однако, чтобы избежать копирования, вы можете использовать массивы numpy:

A=numpy.array([1,2,3])
B=A[:1]
B[0]=17
A /// => A is numpy.array([17,  2,  3])

Так что, если вы используете массивные массивы, копирование данных будет намного меньше, и я подозреваю, что ваш код будет более эффективным. Но, как всегда; сравните это, чтобы подтвердить это.

См. https://stackoverflow.com/a/5131563/922613 или https://scipy -cookbook.readthedocs.io / items / ViewsVsCopies.html для получения дополнительной информации

Я проверил это с помощью следующего скрипта:

import numpy as np

def normal_arrays():
    A=[1,2,3,4]
    B=[1,2,3,4]
    total = 0
    for shift in range (1,len(B)):
        total += np.dot( A[-shift:] , B[:shift] )**2 + np.dot( A[:shift] , B[-shift:] )**2
        total += (np.dot(A,B))**2


def numpy_arrays():
    A=np.array([1,2,3,4])
    B=np.array([1,2,3,4])
    total = 0
    for shift in range (1,len(B)):
        total += np.dot( A[-shift:] , B[:shift] )**2 + np.dot( A[:shift] , B[-shift:] )**2
        total += (np.dot(A,B))**2


if __name__ == "__main__":
    import timeit
    print('normal arrays', timeit.timeit(normal_arrays))
    print('numpy arrays', timeit.timeit(numpy_arrays))

Мои результаты показали улучшение на 50% во время выполнения:

normal arrays 21.76756980700884
numpy arrays 11.998605689994292
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...