Расширения PyTorch C ++: доступ к данным для полутензорных - PullRequest
1 голос
/ 29 октября 2019

Я пытаюсь написать расширение C ++ / CUDA для PyTorch с использованием C ++ Tensor API, и я хотел бы, чтобы мой код работал как с float32, так и с float16 (половинная точность). Я не уверен, как получить доступ к указателю данных для половинных тензоров, поступающих из Python.

Вот как я делаю это для тензоров с плавающей точкой:

// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();

Вот что я пробовал для половинытензор:

// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();

// PyTorch float16 type
// error: no instance of function template "at::Tensor::data" 
A.data<torch::ScalarType::Half>();

// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());

Я попытался просмотреть исходный код API C ++, но не смог найти ничего похожего на тип float16.

Информация о системе: Python 3.6.2 PyTorch 1.00,1

1 Ответ

0 голосов
/ 01 ноября 2019

Правильный тип оказался at::Half.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...