Я сталкиваюсь с проблемой, когда применяю шаблонные методы в функциях оболочки ядра.
Вот коды в моем первоначальном сознании:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
//----------------------------------------
// main.cpp
#include "cuda_demo.cuh"
int main() {
int param = 10;
kernel_wrapper(param);
return 0;
}
Вскоре я обнаружил, что шаблоны должны быть реализованы в файле заголовка (см. Почему шаблоны могут быть реализованы только в заголовкефайл ).
И из этого я получаю два решения, общее из которых - «записать объявление шаблона в заголовочный файл, затем реализовать класс в файле реализации (например, .tpp) и включить этот файл реализации вконец заголовка ".
Поэтому я меняю коды:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
#include "cuda_demo.cu"
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
Компилятор выдает мне следующую ошибку:
error: expected primary-expression before < token
my_kernel<<<1,1>>>(param);
Такая же ошибка возникает, когдаЯ поместил все коды cuda в "cuda_demo.cuh".
Затем я попробовал второе решение следующим образом:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
template void kernel_wrapper<int>(int param);
Это работает хорошо!Но в моем проекте «T» не является простым типом, который может быть рекурсивным, например
Class_1<Class_2<Class_3<...>>>,
. Это означает, что я не могу заранее определить конкретный тип «T».
Кто-нибудь знает, как это решить?
Спасибо.