@@ -189,13 +189,43 @@ inline auto UseBFloat16 = [](const CallExpr *C) -> bool {
189189 return DpctGlobalInfo::useBFloat16 ();
190190};
191191
192+ inline auto IsDirectCallerPureDevice = [](const CallExpr *C) -> bool {
193+ auto ContextFD = getImmediateOuterFuncDecl (C);
194+ while (auto LE = getImmediateOuterLambdaExpr (ContextFD)) {
195+ ContextFD = getImmediateOuterFuncDecl (LE);
196+ }
197+ if (!ContextFD)
198+ return false ;
199+ if ((ContextFD->getAttr <CUDADeviceAttr>() &&
200+ !ContextFD->getAttr <CUDAHostAttr>()) ||
201+ ContextFD->getAttr <CUDAGlobalAttr>()) {
202+ return true ;
203+ }
204+ return false ;
205+ };
206+
207+ inline auto IsDirectCallerPureHost = [](const CallExpr *C) -> bool {
208+ auto ContextFD = getImmediateOuterFuncDecl (C);
209+ while (auto LE = getImmediateOuterLambdaExpr (ContextFD)) {
210+ ContextFD = getImmediateOuterFuncDecl (LE);
211+ }
212+ if (!ContextFD)
213+ return false ;
214+ if (!ContextFD->getAttr <CUDADeviceAttr>() &&
215+ !ContextFD->getAttr <CUDAGlobalAttr>()) {
216+ return true ;
217+ }
218+ return false ;
219+ };
220+
192221inline auto IsPureHost = [](const CallExpr *C) -> bool {
193222 const FunctionDecl *FD = C->getDirectCallee ();
194223 if (!FD)
195224 return false ;
225+ if (!IsDirectCallerPureHost (C))
226+ return false ;
196227 if (!(FD->hasAttr <CUDADeviceAttr>()))
197228 return true ;
198-
199229 SourceLocation DeclLoc =
200230 dpct::DpctGlobalInfo::getSourceManager ().getExpansionLoc (
201231 FD->getLocation ());
@@ -209,22 +239,12 @@ inline auto IsPureHost = [](const CallExpr *C) -> bool {
209239 }
210240 return false ;
211241};
212- inline auto IsPureDevice = makeCheckAnd(
213- HasDirectCallee (),
214- makeCheckAnd (IsDirectCalleeHasAttribute<CUDADeviceAttr>(),
215- makeCheckNot (IsDirectCalleeHasAttribute<CUDAHostAttr>())));
216-
217- inline auto IsDirectCallerPureDevice = [](const CallExpr *C) -> bool {
218- auto ContextFD = getImmediateOuterFuncDecl (C);
219- while (auto LE = getImmediateOuterLambdaExpr (ContextFD)) {
220- ContextFD = getImmediateOuterFuncDecl (LE);
221- }
222- if (!ContextFD)
242+ inline auto IsPureDevice = [](const CallExpr *C) -> bool {
243+ if (!HasDirectCallee ()(C))
223244 return false ;
224- if (ContextFD-> getAttr <CUDADeviceAttr>() &&
225- !ContextFD-> getAttr <CUDAHostAttr>()) {
245+ if (IsDirectCalleeHasAttribute <CUDADeviceAttr>()(C ) &&
246+ !IsDirectCalleeHasAttribute <CUDAHostAttr>()(C))
226247 return true ;
227- }
228248 return false ;
229249};
230250inline auto IsUnresolvedLookupExpr = [](const CallExpr *C) -> bool {
@@ -344,8 +364,9 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
344364 // 4. math_libdevice
345365 // 5. device_std
346366 // c. Host and device
347- // 1. emulation
348- // 2. unsupported_warning
367+ // 1. host_device
368+ // 2. emulation
369+ // 3. unsupported_warning
349370 std::shared_ptr<CallExprRewriter> create (const CallExpr *C) const override {
350371 if (math::IsPureHost (C)) {
351372 // HOST
@@ -355,6 +376,8 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
355376 return HostPerfRewriter.value ().second .second ->create (C);
356377 if (HostNormalRewriter && HostNormalRewriter.value ().first (C))
357378 return HostNormalRewriter.value ().second .second ->create (C);
379+ } else {
380+ return NoRewriteRewriter.value ().second .second ->create (C);
358381 }
359382 } else {
360383 // DEVICE
@@ -378,12 +401,12 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
378401 }
379402
380403 // Host and device
381- if (EmulationRewriter && EmulationRewriter.value ().first (C))
382- return EmulationRewriter.value ().second .second ->create (C);
383-
384404 if (HostDeviceRewriter && HostDeviceRewriter.value ().first (C))
385405 return HostDeviceRewriter.value ().second .second ->create (C);
386406
407+ if (EmulationRewriter && EmulationRewriter.value ().first (C))
408+ return EmulationRewriter.value ().second .second ->create (C);
409+
387410 if (UnsupportedWarningRewriter &&
388411 UnsupportedWarningRewriter.value ().first (C))
389412 return UnsupportedWarningRewriter.value ().second .second ->create (C);
0 commit comments