@@ -251,6 +251,7 @@ enum class Tag : size_t {
251251 ext_experimental, // device API using experimental feature
252252 host_perf, // host API for performance
253253 host_normal, // host API
254+ host_device, // host deivce API
254255 unsupported_warning,
255256 no_rewrite,
256257 tag_size
@@ -280,6 +281,8 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
280281 MathAPIRewriters[static_cast <size_t >(math::Tag::math_libdevice)];
281282 element_t &DeviceStdRewriter =
282283 MathAPIRewriters[static_cast <size_t >(math::Tag::device_std)];
284+ element_t &HostDeviceRewriter =
285+ MathAPIRewriters[static_cast <size_t >(math::Tag::host_device)];
283286 element_t &EmulationRewriter =
284287 MathAPIRewriters[static_cast <size_t >(math::Tag::emulation)];
285288 element_t &ExtExperimentalRewriter =
@@ -378,6 +381,9 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
378381 if (EmulationRewriter && EmulationRewriter.value ().first (C))
379382 return EmulationRewriter.value ().second .second ->create (C);
380383
384+ if (HostDeviceRewriter && HostDeviceRewriter.value ().first (C))
385+ return HostDeviceRewriter.value ().second .second ->create (C);
386+
381387 if (UnsupportedWarningRewriter &&
382388 UnsupportedWarningRewriter.value ().first (C))
383389 return UnsupportedWarningRewriter.value ().second .second ->create (C);
0 commit comments