Я пытаюсь написать расширение C ++ / CUDA для PyTorch с использованием C ++ Tensor API, и я хотел бы, чтобы мой код работал как с float32, так и с float16 (половинная точность). Я не уверен, как получить доступ к указателю данных для половинных тензоров, поступающих из Python.
Вот как я делаю это для тензоров с плавающей точкой:
// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();
Вот что я пробовал для половинытензор:
// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();
// PyTorch float16 type
// error: no instance of function template "at::Tensor::data"
A.data<torch::ScalarType::Half>();
// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());
Я попытался просмотреть исходный код API C ++, но не смог найти ничего похожего на тип float16.
Информация о системе: Python 3.6.2 PyTorch 1.00,1