Отказ от ответственности: я не профилировал этот код, чтобы посмотреть, действительно ли он быстрее на GPU.
Одно векторизованное решение - использовать тензорные представления для трансляции сравнение. Тензорные представления не используют дополнительную память. Вы можете увидеть более подробную информацию в документации
Сначала создайте матрицу, содержащую значения, которые вы хотите сравнить для каждой строки. В данном случае это просто индексы строк.
comparison = torch.tensor(range(max_idx))
Теперь мы будем использовать expand
и unsqueeze
, чтобы сделать виды data_idx
и comparison
одинаковой формы. как filled_matrix
.
comparison_view = comparison.unsqueeze(1).expand(max_idx, max_number_data_idx)
print(comparison_view)
# Each row is the index you want to compare to
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])
data_idx_view = data_idx.expand(max_idx, max_number_data_idx)
print(data_idx_view)
# Each row is a copy of data_idx
# tensor([[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0],
[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0],
[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0],
[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0],
[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0],
[2, 5, 5, 0, 4, 1, 4, 5, 3, 2, 1, 0, 3, 3, 0]])
Мы можем сравнить их равенство и использовать nonzero
, чтобы найти индексы
mask = comparison_view == data_idx_view
mask_indices = mask.nonzero()
print(mask_indices)
# tensor([[ 0, 3],
[ 0, 11],
[ 0, 14],
[ 1, 5],
[ 1, 10],
[ 2, 0],
[ 2, 9],
[ 3, 8],
[ 3, 12],
[ 3, 13],
[ 4, 4],
[ 4, 6],
[ 5, 1],
[ 5, 2],
[ 5, 7]])
Теперь вам просто нужно манипулировать эти результаты в формате, который вы хотите для вывода.
filled_matrix = torch.zeros([max_idx, max_number_data_idx], dtype=torch.int8)
filled_matrix.fill_(-1)
col_indices = [0, 1, 2, 0, 1, 0, 1, 0, 1, 2, 0, 1, 0, 1, 2]
filled_matrix[mask_indices[:, 0], col_indices] = mask_indices[:, 1].type(torch.int8)
Я думал о нескольких вариантах создания списка col_indices
, но ничего не смог придумать без for l oop.
col_indices = torch.zeros(mask_indices.shape[0])
for i in range(1, mask_indices.shape[0]):
if mask_indices[i,0] == mask_indices[i-1,0]:
col_indices[i] = col_indices[i-1]+1
Вам нужно будет выполнить профилирование, чтобы увидеть, какой код на самом деле быстрее.