Конвертировать тензор PyTorch в список Python - PullRequest
0 голосов
/ 23 декабря 2018

Как мне преобразовать PyTorch Tensor в список Python?

Мой текущий пример использования - преобразовать тензор размером [1, 2048, 1, 1] в список из 2048 элементов.

Мой тензор имеет значения с плавающей запятой.Есть ли решение, которое также учитывает int и, возможно, другие типы данных?

1 Ответ

0 голосов
/ 23 декабря 2018

Я нашел Tensor.tolist(), который дает следующий пример использования:

>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803

Итак, чтобы ответить на вопрос, используйте a.squeeze().tolist(), чтобы удалить все измерения размера 1.

Также рассмотрите .flatten(), если список списков не нужен.


До того, как я наткнулся на .tolist(), я использовал:

list = [element.item() for element in tensor.flatten()]

Это выравнивает тензор в одном измерении, затем вызывает .item() для преобразования каждого элемента в число Python.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...