PyTorch: эффективно чередовать два тензора в произвольном порядке - PullRequest
1 голос
/ 29 февраля 2020

Я хочу создать новый тензор z из двух тензоров, скажем x и y с размерами [N_samples, S, N_feats] и [N_samples, T, N_feats] соответственно. Цель состоит в том, чтобы объединить оба тензора во втором диме, смешивая элементы 2-го дим в определенном порядке c, который хранится в переменной order с dim [N_samples, U].

Порядок отличается для каждого образца и, в основном, зависит от того, какой индекс извлечь из какого тензора. Это выглядит так для данного образца order[0] - [x_0, x_1, y_0, x_2, y_1, ... ], где буква указывает тензор, а число указывает индекс 2-го затемнения. Так что z[0] будет

z[0] = [x[0, 0, :], x[0, 1, :], y[0, 0, :], x[0, 2, :], y[0, 1, :] ... ]

Как бы мне этого добиться? Я написал что-то, что использует torch.gather, который пытается это сделать.

x = torch.rand((2, 4, 5))
y = torch.rand((2, 3, 5))

# new ordering of second dim
# positive means take (n-1)th element from x
# negative means take (n-1)th element from y
order = [[1, 2, -1, 3, -2, 4, 3], 
         [1, -1, -2, 2, 3, 4, -3]]

# simple concat for gather
combined = torch.cat([x, y], dim=1)

# add a zero padding on top of combined tensor to ease gather
zero = torch.zeros_like(x)[:, 1:2] 
combined = torch.cat([zero, combined], dim=1)

def _create_index_for_gather(index, offset, n_feats):
    new_index = [abs(i) + offset if i < 0 else i for i in index]

    # need to repeat index for each dim for torch.gather
    new_index = [[x] * n_feats for x in new_index]
    return new_index

_, offset, n_feats = x.shape
index_for_gather = [_create_index_for_gather(i, offset, n_feats) for i in order]

z = combined.gather(dim=1, index=torch.tensor(index_for_gather))

Есть ли более эффективный способ сделать это?

...