diff --git a/src/ompi/impl/all_reduce.h b/src/ompi/impl/all_reduce.h index baf5a11..c89edb7 100644 --- a/src/ompi/impl/all_reduce.h +++ b/src/ompi/impl/all_reduce.h @@ -1,6 +1,8 @@ #ifndef INFINI_CCL_OMPI_IMPL_ALL_REDUCE_H_ #define INFINI_CCL_OMPI_IMPL_ALL_REDUCE_H_ +#include + #include "base/all_reduce.h" #include "communicator.h" #include "dispatcher.h" @@ -66,7 +68,14 @@ class AllReduceImpl { for (size_t i = 0; i < count; ++i) { // TODO(lzm): should later use the unified `Cast` function instead of // static_cast to support CPU custom types. - typed_buf[i] *= static_cast(scale); + if constexpr (std::is_integral_v) { + // Scale in floating point first; casting `scale` to an integer + // type would truncate it to `0` and zero out the result. + typed_buf[i] = + static_cast(static_cast(typed_buf[i]) * scale); + } else { + typed_buf[i] *= static_cast(scale); + } } }); } diff --git a/src/ompi/impl/reduce_scatter.h b/src/ompi/impl/reduce_scatter.h index 88854fc..983f5cf 100644 --- a/src/ompi/impl/reduce_scatter.h +++ b/src/ompi/impl/reduce_scatter.h @@ -2,6 +2,7 @@ #define INFINI_CCL_OMPI_IMPL_REDUCE_SCATTER_H_ #include +#include #include "base/reduce_scatter.h" #include "communicator.h" @@ -79,7 +80,14 @@ class ReduceScatterImpl { for (size_t i = 0; i < recv_count; ++i) { // TODO(lzm): should later use the unified `Cast` function instead of // static_cast to support CPU custom types. - typed_buf[i] *= static_cast(scale); + if constexpr (std::is_integral_v) { + // Scale in floating point first; casting `scale` to an integer + // type would truncate it to `0` and zero out the result. + typed_buf[i] = + static_cast(static_cast(typed_buf[i]) * scale); + } else { + typed_buf[i] *= static_cast(scale); + } } }); }