@@ -346,7 +346,7 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
346346 " cublasLtMatmulHeuristicResult_t" , " CUjit_target" ,
347347 " cublasLtMatrixTransformDesc_t" , " cudaGraphicsMapFlags" ,
348348 " cudaGraphicsRegisterFlags" , " cudaExternalMemoryHandleType" ,
349- " CUstreamCallback" ))))))
349+ " CUstreamCallback" , " cudaHostFn_t " ))))))
350350 .bind (" cudaTypeDef" ),
351351 this );
352352
@@ -4622,7 +4622,8 @@ void KernelCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
46224622 this );
46234623
46244624 auto launchAPIName = [&]() {
4625- return hasAnyName (" cudaLaunchKernel" , " cudaLaunchCooperativeKernel" );
4625+ return hasAnyName (" cudaLaunchKernel" , " cudaLaunchCooperativeKernel" ,
4626+ " cudaLaunchHostFunc" );
46264627 };
46274628 MF.addMatcher (
46284629 callExpr (allOf (callee (functionDecl (launchAPIName ())), parentStmt ()))
@@ -4837,56 +4838,89 @@ void KernelCallRule::runRule(
48374838 LaunchKernelCall = getNodeAsType<CallExpr>(Result, " launchUsed" );
48384839 IsAssigned = true ;
48394840 }
4840- if (!LaunchKernelCall)
4841+ auto FD = LaunchKernelCall->getDirectCallee ();
4842+ if (!LaunchKernelCall || !FD)
48414843 return ;
4842- const Expr *CalleeDRE = LaunchKernelCall->getArg (0 );
4843- bool IsFuncTypeErased = true ;
4844- auto QT = CalleeDRE->getType ();
4845-
4846- if (QT->isPointerType ()) {
4847- QT = QT->getPointeeType ();
4848- }
4849- if (QT->isFunctionType ()) {
4850- IsFuncTypeErased = false ;
4851- }
4852-
4853- if (!getAddressedRef (CalleeDRE)) {
4854- if (IsFuncTypeErased) {
4855- DpctGlobalInfo::setCVersionCUDALaunchUsed ();
4844+ std::string FuncName = FD->getNameAsString ();
4845+ std::cout << FuncName << std::endl;
4846+ if (FuncName == " cudaLaunchHostFunc" ) {
4847+ if (DpctGlobalInfo::getUsmLevel () != UsmLevel::UL_Restricted) {
4848+ report (LaunchKernelCall->getBeginLoc (), Diagnostics::API_NOT_MIGRATED,
4849+ false , " cudaLaunchHostFunc" );
4850+ return ;
48564851 }
48574852 std::string ReplStr;
48584853 llvm::raw_string_ostream OS (ReplStr);
4854+ std::string IndentStr = getIndent (LaunchKernelCall->getBeginLoc (),
4855+ DpctGlobalInfo::getSourceManager ())
4856+ .str ();
48594857 if (IsAssigned) {
48604858 OS << MapNames::getCheckErrorMacroName () << " (" ;
48614859 }
4862- OS << MapNames::getDpctNamespace () << " kernel_launcher::launch(" ;
4863- size_t ArgsNum = LaunchKernelCall->getNumArgs ();
4864- for (size_t i = 0 ; i < ArgsNum; i++) {
4865- if (auto Arg = LaunchKernelCall->getArg (i)) {
4866- if (i == 0 ) {
4867- if (auto E = getAddressedRef (CalleeDRE, false , nullptr )) {
4868- OS << ExprAnalysis::ref (E);
4860+ OS << ExprAnalysis::ref (LaunchKernelCall->getArg (0 ))
4861+ << " ->submit([&](sycl::handler &cgh) {" << getNL () << IndentStr
4862+ << " cgh.host_task([=](){" << getNL () << IndentStr << " "
4863+ << ExprAnalysis::ref (LaunchKernelCall->getArg (1 )) << " ("
4864+ << ExprAnalysis::ref (LaunchKernelCall->getArg (2 )) << " );" << getNL ()
4865+ << IndentStr << " });" << getNL () << IndentStr << " })" ;
4866+ if (IsAssigned) {
4867+ OS << " )" ;
4868+ }
4869+ auto Repl = new ReplaceStmt (LaunchKernelCall, OS.str ());
4870+ Repl->setBlockLevelFormatFlag ();
4871+ emplaceTransformation (Repl);
4872+ return ;
4873+ } else {
4874+ const Expr *CalleeDRE = LaunchKernelCall->getArg (0 );
4875+ bool IsFuncTypeErased = true ;
4876+ auto QT = CalleeDRE->getType ();
4877+
4878+ if (QT->isPointerType ()) {
4879+ QT = QT->getPointeeType ();
4880+ }
4881+ if (QT->isFunctionType ()) {
4882+ IsFuncTypeErased = false ;
4883+ }
4884+
4885+ if (!getAddressedRef (CalleeDRE)) {
4886+ if (IsFuncTypeErased) {
4887+ DpctGlobalInfo::setCVersionCUDALaunchUsed ();
4888+ }
4889+ std::string ReplStr;
4890+ llvm::raw_string_ostream OS (ReplStr);
4891+ if (IsAssigned) {
4892+ OS << MapNames::getCheckErrorMacroName () << " (" ;
4893+ }
4894+ OS << MapNames::getDpctNamespace () << " kernel_launcher::launch(" ;
4895+ size_t ArgsNum = LaunchKernelCall->getNumArgs ();
4896+ for (size_t i = 0 ; i < ArgsNum; i++) {
4897+ if (auto Arg = LaunchKernelCall->getArg (i)) {
4898+ if (i == 0 ) {
4899+ if (auto E = getAddressedRef (CalleeDRE, false , nullptr )) {
4900+ OS << ExprAnalysis::ref (E);
4901+ } else {
4902+ OS << ExprAnalysis::ref (Arg);
4903+ }
48694904 } else {
4870- OS << ExprAnalysis::ref (Arg);
4905+ OS << " , " << ExprAnalysis::ref (Arg);
48714906 }
4872- } else {
4873- OS << " , " << ExprAnalysis::ref (Arg);
48744907 }
48754908 }
4876- }
4877- OS << " )" ;
4878- if (IsAssigned) {
48794909 OS << " )" ;
4910+ if (IsAssigned) {
4911+ OS << " )" ;
4912+ }
4913+ emplaceTransformation (new ReplaceStmt (LaunchKernelCall, OS.str ()));
4914+ return ;
48804915 }
4881- emplaceTransformation (new ReplaceStmt (LaunchKernelCall, OS.str ()));
4882- return ;
4883- }
48844916
4885- if (!IsAssigned)
4886- findAndRemoveTrailingSemicolon (LaunchKernelCall, Result);
4887- if (DpctGlobalInfo::getInstance ().buildLaunchKernelInfo (LaunchKernelCall,
4888- IsAssigned)) {
4889- emplaceTransformation (new ReplaceStmt (LaunchKernelCall, true , false , " " ));
4917+ if (!IsAssigned)
4918+ findAndRemoveTrailingSemicolon (LaunchKernelCall, Result);
4919+ if (DpctGlobalInfo::getInstance ().buildLaunchKernelInfo (LaunchKernelCall,
4920+ IsAssigned)) {
4921+ emplaceTransformation (
4922+ new ReplaceStmt (LaunchKernelCall, true , false , " " ));
4923+ }
48904924 }
48914925 }
48924926}
0 commit comments