Я использую тензорный поток для машинного обучения, и мне нужно заранее сделать некоторые преобразования в данных. Узким местом, с которым я сталкиваюсь, является моя функция - взять массив чисел, сравнить каждое число с любым другим числом и создать таблицу этих сравнений.
Функция следующая:
def compare(list, length):
result = np.zeros(length*length)
i=0
for row in range(length):
for col in range(length):
if row != col:
result[i] = list[col] - list[row]
else:
result[i] = list[row]
i = i + 1
return result.reshape((length,length))
Это мучительно медленно, есть ли способ использовать мой GPU для достижения того же результата? Или хотя бы оптимизировать эту функцию, чтобы она выполнялась быстрее?
Идея состоит в том, чтобы взять список чисел типа
0 1 2 3 4 5
и сгенерировать таблицу сравнения с разница между столбцом и строкой, за исключением случая, когда значение сравнивается с самим собой, и в этом случае оно просто возвращает себя.
[[ 0. 1. 2. 3. 4. 5.]
[-1. 1. 1. 2. 3. 4.]
[-2. -1. 2. 1. 2. 3.]
[-3. -2. -1. 3. 1. 2.]
[-4. -3. -2. -1. 4. 1.]
[-5. -4. -3. -2. -1. 5.]]
Я пытался использовать @jit, но результаты кажутся странно медленно
@jit(nopython=True, parallel=True)
def compare2(list, length):
result = np.zeros(length*length)
i=0
for row in range(length):
for col in range(length):
if row != col:
result[i] = list[col] - list[row]
else:
result[i] = list[row]
i = i + 1
return result.reshape((length,length))
дает следующий результат в списке из ~ 300 элементов:
compare: Elapsed time is 0.062655 seconds.
compare2: Elapsed time is 0.423727 seconds.
РЕДАКТИРОВАТЬ:
Я использую функция сравнения внутри другого l oop для генерации списка таблиц. Я объединил их в одну функцию, и теперь @jit превосходит исходную функцию в 15 раз.