@@ -4512,10 +4512,14 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
45124512}
45134513
45144514void KernelCallRefRule::registerMatcher (ast_matchers::MatchFinder &MF) {
4515- MF.addMatcher (declRefExpr (allOf (to (functionDecl (hasAttr (attr::CUDAGlobal))),
4516- unless (hasAncestor (cudaKernelCallExpr ()))))
4517- .bind (" kernelRef" ),
4518- this );
4515+ MF.addMatcher (
4516+ functionDecl (
4517+ forEachDescendant (
4518+ declRefExpr (allOf (to (functionDecl (hasAttr (attr::CUDAGlobal))),
4519+ unless (hasAncestor (cudaKernelCallExpr ()))))
4520+ .bind (" kernelRef" )))
4521+ .bind (" outerFunc" ),
4522+ this );
45194523 MF.addMatcher (unresolvedLookupExpr (unless (hasAncestor (cudaKernelCallExpr ())))
45204524 .bind (" unresolvedRef" ),
45214525 this );
@@ -4591,6 +4595,11 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
45914595void KernelCallRefRule::runRule (
45924596 const ast_matchers::MatchFinder::MatchResult &Result) {
45934597 if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, " kernelRef" )) {
4598+ const FunctionDecl *OuterFD =
4599+ getAssistNodeAsType<FunctionDecl>(Result, " outerFunc" );
4600+ if (!OuterFD) {
4601+ return ;
4602+ }
45944603 if (auto ParentCE = DpctGlobalInfo::findAncestor<CallExpr>(DRE)) {
45954604 if (auto Callee = ParentCE->getDirectCallee ()) {
45964605 if (dpct::DpctGlobalInfo::isInCudaPath (Callee->getBeginLoc ())) {
@@ -4614,31 +4623,25 @@ void KernelCallRefRule::runRule(
46144623 DFI->collectInfoForWrapper (FD);
46154624 }
46164625 }
4617- if (auto *OuterFD = DpctGlobalInfo::findAncestor<FunctionDecl>(DRE)) {
4618- if ((OuterFD->getTemplatedKind () ==
4619- FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4620- (OuterFD->getTemplatedKind () ==
4621- FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4622- std::string TypeRepl;
4623- if (DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4624- if ((IsTemplateRelated &&
4625- (!DRE->hasExplicitTemplateArgs () ||
4626- (DRE->getNumTemplateArgs () <= TemplateParamNum))) ||
4627- DRE->hadMultipleCandidates ()) {
4628- TypeRepl = getTypeRepl (DRE);
4629- }
4626+ if ((OuterFD->getTemplatedKind () ==
4627+ FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4628+ (OuterFD->getTemplatedKind () ==
4629+ FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4630+ std::string TypeRepl;
4631+ if (DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4632+ if ((IsTemplateRelated &&
4633+ (!DRE->hasExplicitTemplateArgs () ||
4634+ (DRE->getNumTemplateArgs () <= TemplateParamNum))) ||
4635+ DRE->hadMultipleCandidates ()) {
4636+ TypeRepl = getTypeRepl (DRE);
46304637 }
4631- insertWrapperPostfix<DeclRefExpr>(
4632- DRE, std::move (TypeRepl),
4633- DpctGlobalInfo::isCVersionCUDALaunchUsed ());
46344638 }
4639+ insertWrapperPostfix<DeclRefExpr>(
4640+ DRE, std::move (TypeRepl), DpctGlobalInfo::isCVersionCUDALaunchUsed ());
46354641 }
46364642 }
46374643 if (auto ULE =
46384644 getAssistNodeAsType<UnresolvedLookupExpr>(Result, " unresolvedRef" )) {
4639- if (!DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4640- return ;
4641- }
46424645 bool KernelRefFound = false ;
46434646 for (auto *D : ULE->decls ()) {
46444647 const FunctionDecl *FD = dyn_cast<FunctionDecl>(D);
@@ -4670,7 +4673,8 @@ void KernelCallRefRule::runRule(
46704673 }
46714674 }
46724675 }
4673- insertWrapperPostfix<UnresolvedLookupExpr>(ULE, getTypeRepl (ULE), true );
4676+ insertWrapperPostfix<UnresolvedLookupExpr>(
4677+ ULE, getTypeRepl (ULE), DpctGlobalInfo::isCVersionCUDALaunchUsed ());
46744678 }
46754679}
46764680
0 commit comments