Удалить LongTensor из набора LongTensor в Python / Pytorch - PullRequest
1 голос
/ 29 апреля 2020

У меня есть набор, содержащий несколько LongTensors, и мне нужно удалить некоторые из LongTensors из набора, есть ли эффективный способ сделать это в pytorch?

import torch
ks = {torch.LongTensor([1, 3]), torch.LongTensor([2, 3]), torch.LongTensor([3, 3])}
p = torch.LongTensor([1, 3])
ks.remove(p)

Приведенный выше метод возвращает

KeyError: тензор ([1, 3])

Есть ли эффективный способ удаления тензора ??

1 Ответ

0 голосов
/ 29 апреля 2020

Выполните итерацию по списку , чтобы сравнить каждый элемент по одному, а затем преобразовать list обратно в set.

>>> set([x for x in ks if (x!=p).any()])
{tensor([3, 3]), tensor([2, 3])}
...