@@ -346,7 +346,7 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
346346 " cublasLtMatmulHeuristicResult_t" , " CUjit_target" ,
347347 " cublasLtMatrixTransformDesc_t" , " cudaGraphicsMapFlags" ,
348348 " cudaGraphicsRegisterFlags" , " cudaExternalMemoryHandleType" ,
349- " CUstreamCallback" ))))))
349+ " CUstreamCallback" , " cudaHostFn_t " ))))))
350350 .bind (" cudaTypeDef" ),
351351 this );
352352
@@ -4367,11 +4367,12 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
43674367 StreamName = " {{NEEDREPLACEQ" + std::to_string (Index) + " }}." ;
43684368 ReplStr = StreamName + " ext_oneapi_empty()" ;
43694369 } else {
4370- StreamName = getStmtSpelling (StreamArg);
4370+ ExprAnalysis EA (StreamArg);
4371+ ReplStr = EA.getReplacedString ();
43714372 if (needExtraParensInMemberExpr (StreamArg)) {
4372- StreamName = " (" + StreamName + " )" ;
4373+ ReplStr = " (" + ReplStr + " )" ;
43734374 }
4374- ReplStr = StreamName + " ->" + " ext_oneapi_empty()" ;
4375+ ReplStr = ReplStr + " ->" + " ext_oneapi_empty()" ;
43754376 }
43764377 if (IsAssigned) {
43774378 ReplStr = MapNames::getCheckErrorMacroName () + " ((" + ReplStr + " ))" ;
@@ -4414,7 +4415,12 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
44144415
44154416 StmtStr0 = " {{NEEDREPLACEQ" + std::to_string (Index) + " }}." ;
44164417 } else {
4417- StmtStr0 = getStmtSpelling (CE->getArg (0 )) + " ->" ;
4418+ ExprAnalysis StreamArgEA (StreamArg);
4419+ StmtStr0 = StreamArgEA.getReplacedString ();
4420+ if (needExtraParensInMemberExpr (StreamArg)) {
4421+ StmtStr0 = " (" + StmtStr0 + " )" ;
4422+ }
4423+ StmtStr0 += " ->" ;
44184424 }
44194425 ReplStr = StmtStr0 + " ext_oneapi_submit_barrier({" +
44204426 StmtStr1 + " })" ;
@@ -4622,7 +4628,8 @@ void KernelCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
46224628 this );
46234629
46244630 auto launchAPIName = [&]() {
4625- return hasAnyName (" cudaLaunchKernel" , " cudaLaunchCooperativeKernel" );
4631+ return hasAnyName (" cudaLaunchKernel" , " cudaLaunchCooperativeKernel" ,
4632+ " cudaLaunchHostFunc" );
46264633 };
46274634 MF.addMatcher (
46284635 callExpr (allOf (callee (functionDecl (launchAPIName ())), parentStmt ()))
@@ -4837,56 +4844,89 @@ void KernelCallRule::runRule(
48374844 LaunchKernelCall = getNodeAsType<CallExpr>(Result, " launchUsed" );
48384845 IsAssigned = true ;
48394846 }
4840- if (!LaunchKernelCall)
4847+ auto FD = LaunchKernelCall->getDirectCallee ();
4848+ if (!LaunchKernelCall || !FD)
48414849 return ;
4842- const Expr *CalleeDRE = LaunchKernelCall->getArg (0 );
4843- bool IsFuncTypeErased = true ;
4844- auto QT = CalleeDRE->getType ();
4845-
4846- if (QT->isPointerType ()) {
4847- QT = QT->getPointeeType ();
4848- }
4849- if (QT->isFunctionType ()) {
4850- IsFuncTypeErased = false ;
4851- }
4852-
4853- if (!getAddressedRef (CalleeDRE)) {
4854- if (IsFuncTypeErased) {
4855- DpctGlobalInfo::setCVersionCUDALaunchUsed ();
4850+ std::string FuncName = FD->getNameAsString ();
4851+ std::cout << FuncName << std::endl;
4852+ if (FuncName == " cudaLaunchHostFunc" ) {
4853+ if (DpctGlobalInfo::getUsmLevel () != UsmLevel::UL_Restricted) {
4854+ report (LaunchKernelCall->getBeginLoc (), Diagnostics::API_NOT_MIGRATED,
4855+ false , " cudaLaunchHostFunc" );
4856+ return ;
48564857 }
48574858 std::string ReplStr;
48584859 llvm::raw_string_ostream OS (ReplStr);
4860+ std::string IndentStr = getIndent (LaunchKernelCall->getBeginLoc (),
4861+ DpctGlobalInfo::getSourceManager ())
4862+ .str ();
48594863 if (IsAssigned) {
48604864 OS << MapNames::getCheckErrorMacroName () << " (" ;
48614865 }
4862- OS << MapNames::getDpctNamespace () << " kernel_launcher::launch(" ;
4863- size_t ArgsNum = LaunchKernelCall->getNumArgs ();
4864- for (size_t i = 0 ; i < ArgsNum; i++) {
4865- if (auto Arg = LaunchKernelCall->getArg (i)) {
4866- if (i == 0 ) {
4867- if (auto E = getAddressedRef (CalleeDRE, false , nullptr )) {
4868- OS << ExprAnalysis::ref (E);
4866+ OS << ExprAnalysis::ref (LaunchKernelCall->getArg (0 ))
4867+ << " ->submit([&](sycl::handler &cgh) {" << getNL () << IndentStr
4868+ << " cgh.host_task([=](){" << getNL () << IndentStr << " "
4869+ << ExprAnalysis::ref (LaunchKernelCall->getArg (1 )) << " ("
4870+ << ExprAnalysis::ref (LaunchKernelCall->getArg (2 )) << " );" << getNL ()
4871+ << IndentStr << " });" << getNL () << IndentStr << " })" ;
4872+ if (IsAssigned) {
4873+ OS << " )" ;
4874+ }
4875+ auto Repl = new ReplaceStmt (LaunchKernelCall, OS.str ());
4876+ Repl->setBlockLevelFormatFlag ();
4877+ emplaceTransformation (Repl);
4878+ return ;
4879+ } else {
4880+ const Expr *CalleeDRE = LaunchKernelCall->getArg (0 );
4881+ bool IsFuncTypeErased = true ;
4882+ auto QT = CalleeDRE->getType ();
4883+
4884+ if (QT->isPointerType ()) {
4885+ QT = QT->getPointeeType ();
4886+ }
4887+ if (QT->isFunctionType ()) {
4888+ IsFuncTypeErased = false ;
4889+ }
4890+
4891+ if (!getAddressedRef (CalleeDRE)) {
4892+ if (IsFuncTypeErased) {
4893+ DpctGlobalInfo::setCVersionCUDALaunchUsed ();
4894+ }
4895+ std::string ReplStr;
4896+ llvm::raw_string_ostream OS (ReplStr);
4897+ if (IsAssigned) {
4898+ OS << MapNames::getCheckErrorMacroName () << " (" ;
4899+ }
4900+ OS << MapNames::getDpctNamespace () << " kernel_launcher::launch(" ;
4901+ size_t ArgsNum = LaunchKernelCall->getNumArgs ();
4902+ for (size_t i = 0 ; i < ArgsNum; i++) {
4903+ if (auto Arg = LaunchKernelCall->getArg (i)) {
4904+ if (i == 0 ) {
4905+ if (auto E = getAddressedRef (CalleeDRE, false , nullptr )) {
4906+ OS << ExprAnalysis::ref (E);
4907+ } else {
4908+ OS << ExprAnalysis::ref (Arg);
4909+ }
48694910 } else {
4870- OS << ExprAnalysis::ref (Arg);
4911+ OS << " , " << ExprAnalysis::ref (Arg);
48714912 }
4872- } else {
4873- OS << " , " << ExprAnalysis::ref (Arg);
48744913 }
48754914 }
4876- }
4877- OS << " )" ;
4878- if (IsAssigned) {
48794915 OS << " )" ;
4916+ if (IsAssigned) {
4917+ OS << " )" ;
4918+ }
4919+ emplaceTransformation (new ReplaceStmt (LaunchKernelCall, OS.str ()));
4920+ return ;
48804921 }
4881- emplaceTransformation (new ReplaceStmt (LaunchKernelCall, OS.str ()));
4882- return ;
4883- }
48844922
4885- if (!IsAssigned)
4886- findAndRemoveTrailingSemicolon (LaunchKernelCall, Result);
4887- if (DpctGlobalInfo::getInstance ().buildLaunchKernelInfo (LaunchKernelCall,
4888- IsAssigned)) {
4889- emplaceTransformation (new ReplaceStmt (LaunchKernelCall, true , false , " " ));
4923+ if (!IsAssigned)
4924+ findAndRemoveTrailingSemicolon (LaunchKernelCall, Result);
4925+ if (DpctGlobalInfo::getInstance ().buildLaunchKernelInfo (LaunchKernelCall,
4926+ IsAssigned)) {
4927+ emplaceTransformation (
4928+ new ReplaceStmt (LaunchKernelCall, true , false , " " ));
4929+ }
48904930 }
48914931 }
48924932}
0 commit comments