Как мне обновить тензор в Pytorch после индексации дважды? - PullRequest
1 голос
/ 09 апреля 2019

Я знаю, как обновить тензор после индексации на его части, например:

import torch

b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)

, но есть ли способ, которым я могу обновить оригинальный тензор после индексации в нем дважды?Например,

i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)

Мне бы хотелось, чтобы b было tensor([0, 1, 0, 2]) в конце.Есть ли способ сделать это?

Я знаю, что я могу сделать

masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)

, но есть ли лучший способ?Кажется, что это должно быть неэффективно;если masked очень большой, я обновляю много мест в b, когда я действительно изменил только одну.

(Если другой подход, чем индексирование дважды, будет работать лучше, общая проблема, которую яесть, как изменить значение в исходном тензоре в i-ом местоположении маскированной версии этого тензора.)

1 Ответ

1 голос
/ 09 апреля 2019

Я принял другое решение из здесь и сравнил его с вашим решением:

Решение:

b[b.nonzero()[i]] = 2

Сравнение во время выполнения:

import torch as t
import numpy as np
import timeit


if __name__ == "__main__":

    np.random.seed(12345)
    b = t.tensor(np.random.randint(0,2, [1000]), dtype=t.uint8)
    # inconvenient way to think of a random index halfway that is 1.
    halfway = np.array(list(range(len(b))))[b == 1][len(b[b == 1]) //2]

    runs = 100000

    elapsed1 = timeit.timeit("mask=b[b]; mask[halfway] = 2; b[b] = mask", 
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (original): {:.6f} ms per call".format(elapsed1 / runs))

    elapsed2 = timeit.timeit("b[b.nonzero()[halfway]]=2",
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (improved): {:.6f} ms per call".format(elapsed2 / runs))

Результаты:

Time taken (original): 0.000096 ms per call
Time taken (improved): 0.000047 ms per call

Результаты для вектора длины 100000

Time taken: 0.010284 ms per call
Time taken: 0.003667 ms per call

Поэтому решения отличаются только в 2 раза. Я не уверен, что это оптимальное решение, но в зависимости от вашего размера (и от того, как часто вы вызываете функцию) это должно дать вам приблизительное представление о том, на что вы смотрите.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...