Я хотел бы написать функцию, которая оборачивает MPI_Allreduce и которая принимает любой двоичный оператор (как std :: redu), который будет использоваться MPI как оператор сокращения. В частности, пользователь такой функции может использовать лямбду.
Следующий простой пример кода показывает, что:
#include <mpi.h>
#include <iostream>
#include <functional>
template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {
auto lambda=[op](void *a, void *b, int *len, MPI_Datatype *){
double *aa=static_cast<double *>(a);
double *bb=static_cast<double *>(bb);
for (int i=0; i<*len; ++i) {
bb[i]=op(aa[i], bb[i]);
}
};
// MPI_User_function is a typedef to: void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * )
MPI_User_function *opPtr=/* black magic code that get the function pointer from the lambda */;
MPI_Op mpiOp;
MPI_Op_create(*opPtr, 1, &mpiOp);
MPI_Allreduce(MPI_IN_PLACE, data, len, MPI_DOUBLE, mpiOp, MPI_COMM_WORLD);
MPI_Op_free(&mpiOp);
}
int main() {
MPI_Init(nullptr, nullptr);
double data[4]={1.,2.,3.,4.};
reduce(data, 4, [](double a, double b){return a+b;});
int pRank;
MPI_Comm_rank(MPI_COMM_WORLD, &pRank);
if (pRank==0) {
for (int i=0; i<4; ++i) {
std::cout << data[i] << " ";
}
std::cout << std::endl;
}
MPI_Finalize();
return 1;
}
Недостающая часть - это код, который получает указатель на функцию от лямбды в функции reduce
. Из нескольких связанных вопросов эта проблема получения указателя на функцию из захвата лямбды кажется сложной, но ее можно решить. Но мне не удалось заставить что-то работать с этим простым кодом (я попробовал несколько хитростей с std :: function, std :: bind, хранением лямбда-выражения в переменной stati c) ... Так что небольшая помощь была бы полезна!
EDIT: После ответа @noma я попробовал следующий упрощенный код без MPI в goldbolt
#include <iostream>
#include <functional>
typedef double MPI_Datatype;
template<typename BinaryOp, BinaryOp op> // older standards
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
{}
template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {
typedef void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * );
MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
}
int main() {
double data[4]={1.,2.,3.,4.};
reduce(data, 4, [](double a, double b){return a+b;});
return 1;
}
Он компилируется на некоторых компиляторах. Вот результаты:
- i cc> = 19.0.1 (с -std = c ++ 17): OK
- clang ++> = 5.0.0 (с - -std = c ++ 17): OK
- clang ++ 10.0.0 (с --std = c ++ 14): NOK
- g ++ 9.3 (с --std = c ++ 17): NOK
- i cc> = 19.0.0 (с -std = c ++ 17): NOK
Сообщение об ошибке при i cc 19.0. 0 с -std = c ++ 17 (или i cc 19.0.1 с -std = c ++ 14) интересно:
<source>(15): error: expression must have a constant value
MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
^
detected during instantiation of "void reduce(double *, int, BinaryOp) [with BinaryOp=lambda [](double, double)->double]" at line 21
И действительно, я не совсем понимаю прохождение переменной 'op', являющейся аргументом времени выполнения функции reduce
в качестве второго параметра шаблона функции non_lambda
... Является ли это неясной функциональностью c ++ 17, которую поддерживают только некоторые компиляторы?