Pytorch Конкатенация строк в альтернативном порядке - PullRequest
1 голос
/ 04 апреля 2020

Я пытаюсь кодировать позиционную кодировку на бумаге трансформаторов. Для этого мне нужно выполнить операцию, аналогичную следующей:

a = torch.arange(20).reshape(4,5) 
b = a * 2
c = torch.cat([torch.stack([a_row,b_row]) for a_row, b_row in zip(a,b)])

Мне кажется, что может быть более быстрый способ сделать это? возможно, добавив измерение к a и b?

Ответы [ 2 ]

2 голосов
/ 04 апреля 2020

Я бы просто использовал для этого оператор присваивания:

c = torch.zeros(8, 5)
c[::2, :] = a   # Index every second row, starting from 0
c[1::2, :] = b  # Index every second row, starting from 1 

При расчете времени для двух решений я использовал следующее:

import timeit
import torch
a = torch.arange(20).reshape(4,5) 
b = a * 2

suggested = timeit.timeit("c = torch.cat([torch.stack([a_row, b_row]) for a_row, b_row in zip (a, b)])", 
                          setup="import torch; from __main__ import a, b", number=10000)
print(suggested/10000)
# 4.5105120493099096e-05

improved = timeit.timeit("c = torch.zeros(8, 5); c[::2, :] = a; c[1::2, :] = b", 
                         setup="import torch; from __main__ import a, b", number=10000)
print(improved/10000)
# 2.1489459509029985e-05

Второй подход требует значительно меньше (приблизительно половину) времени, хотя одна итерация все еще очень быстра. Конечно, вам придется проверить это для ваших реальных тензорных размеров, но это самое простое решение, которое я мог придумать. Не могу дождаться, чтобы увидеть, есть ли у кого-нибудь изящное низкоуровневое решение для этого, которое еще быстрее!

Кроме того, имейте в виду, что я не рассчитывал время создания b, предполагая, что тензоры вы хочу, чтобы переплетение уже дано.

1 голос
/ 06 апреля 2020

Так получается, что простая конкатенация и изменение формы делают свое дело:

c = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])

Когда я рассчитал это со следующим, он был примерно в 2,3 раза быстрее, чем ответ @ dennlinger:

improved2 = timeit.timeit("c = torch.cat([a, b], dim=-1).view(-1, a.shape[-1])",  
                          setup="import torch; from __main__ import a, b", 
                          number=10000) 
print(improved2/10000) 
# 7.253780400003507e-06
print(improved / improved2)
# 2.3988091506044955
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...