PyTorch эквивалент index_add_, который принимает максимум вместо - PullRequest
0 голосов
/ 30 мая 2018

В PyTorch метод Tensor index_add_ выполняет суммирование с использованием предоставленного индексного тензора:

idx = torch.LongTensor([0,0,0,0,1,1])
child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
parent = torch.FloatTensor([0, 0])
parent.index_add_(0, idx, child)

Первые четыре дочерних значения суммируются в parent [0], а следующие два переходят в parent[1], поэтому результат равен tensor([ 19., 9.])

Однако вместо этого мне нужно сделать index_max_, чего нет в API.Есть ли способ сделать это эффективно (без необходимости зацикливать или выделять больше памяти)?Одно (плохое) решение для цикла было бы:

for i in range(max(idx)+1):
    parent[i] = torch.max(child[idx == i])

Это дает желаемый результат tensor([ 10., 8.]), но очень медленно.

1 Ответ

0 голосов
/ 30 мая 2018

Решение, играющее с индексами:

def index_max(child, idx, num_partitions): 
    # Building a num_partition x num_samples matrix `idx_tiled`:
    partition_idx = torch.range(0, num_partitions - 1, dtype=torch.long)
    partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
    idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
    idx_tiled = (idx_tiled == partition_idx).float()
    # i.e. idx_tiled[i,j] == 1 if idx[j] == i, else 0

    parent = idx_tiled * child
    parent, _ = torch.max(parent, dim=1)
    return parent

Бенчмаркинг:

import timeit

setup = '''
import torch

def index_max_v0(child, idx, num_partitions):
    parent = torch.zeros(num_partitions)
    for i in range(max(idx) + 1):
        parent[i] = torch.max(child[idx == i])
    return parent

def index_max(child, idx, num_partitions):

    # Building a num_partition x num_samples matrix `idx_tiled` 
    # containing for each row indices of
    partition_idx = torch.range(0, num_partitions - 1, dtype=torch.long)
    partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
    idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
    idx_tiled = (idx_tiled == partition_idx).float()

    parent = idx_tiled * child
    parent, _ = torch.max(parent, dim=1)
    return parent

idx = torch.LongTensor([0,0,0,0,1,1])
child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
num_partitions = torch.unique(idx).shape[0]

'''
print(min(timeit.Timer('index_max_v0(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
# > 0.05308796599274501
print(min(timeit.Timer('index_max(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
# > 0.024736385996220633
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...