@@ -388,29 +388,29 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
388388 return Rewriter.value ();
389389 }
390390 }
391- if (math::IsUnresolvedLookupExpr (C)) {
392- if (math::IsDirectCallerPureDevice (C)) {
393- if (Rewriter = getDeviceRewriter (C))
394- return Rewriter.value ();
395- }
396- }
397391 if (math::IsDefinedInCUDA ()(C)) {
398392 if (Rewriter = getDeviceRewriter (C))
399393 return Rewriter.value ();
400394 }
401395 }
402-
403- // Host and device
404- if (HostDeviceRewriter && HostDeviceRewriter.value ().first (C))
405- return HostDeviceRewriter.value ().second .second ->create (C);
406-
407- if (EmulationRewriter && EmulationRewriter.value ().first (C))
408- return EmulationRewriter.value ().second .second ->create (C);
409-
410- if (UnsupportedWarningRewriter &&
411- UnsupportedWarningRewriter.value ().first (C))
412- return UnsupportedWarningRewriter.value ().second .second ->create (C);
413-
396+ if (math::IsUnresolvedLookupExpr (C)) {
397+ if (math::IsDirectCallerPureDevice (C)) {
398+ if (auto Rewriter = getDeviceRewriter (C))
399+ return Rewriter.value ();
400+ } else if (math::IsDirectCallerPureHost (C)) {
401+ return NoRewriteRewriter.value ().second .second ->create (C);
402+ }
403+ }
404+ if (!math::IsDefinedByUser ()(C)) {
405+ // Host and device
406+ if (HostDeviceRewriter && HostDeviceRewriter.value ().first (C))
407+ return HostDeviceRewriter.value ().second .second ->create (C);
408+ if (EmulationRewriter && EmulationRewriter.value ().first (C))
409+ return EmulationRewriter.value ().second .second ->create (C);
410+ if (UnsupportedWarningRewriter &&
411+ UnsupportedWarningRewriter.value ().first (C))
412+ return UnsupportedWarningRewriter.value ().second .second ->create (C);
413+ }
414414 return NoRewriteRewriter.value ().second .second ->create (C);
415415 }
416416};
0 commit comments