Я бы просто использовал для этого оператор присваивания:
c = torch.zeros(8, 5)
c[::2, :] = a # Index every second row, starting from 0
c[1::2, :] = b # Index every second row, starting from 1
При расчете времени для двух решений я использовал следующее:
import timeit
import torch
a = torch.arange(20).reshape(4,5)
b = a * 2
suggested = timeit.timeit("c = torch.cat([torch.stack([a_row, b_row]) for a_row, b_row in zip (a, b)])",
setup="import torch; from __main__ import a, b", number=10000)
print(suggested/10000)
# 4.5105120493099096e-05
improved = timeit.timeit("c = torch.zeros(8, 5); c[::2, :] = a; c[1::2, :] = b",
setup="import torch; from __main__ import a, b", number=10000)
print(improved/10000)
# 2.1489459509029985e-05
Второй подход требует значительно меньше (приблизительно половину) времени, хотя одна итерация все еще очень быстра. Конечно, вам придется проверить это для ваших реальных тензорных размеров, но это самое простое решение, которое я мог придумать. Не могу дождаться, чтобы увидеть, есть ли у кого-нибудь изящное низкоуровневое решение для этого, которое еще быстрее!
Кроме того, имейте в виду, что я не рассчитывал время создания b
, предполагая, что тензоры вы хочу, чтобы переплетение уже дано.