Skip to content

Commit 45bd042

Browse files
authored
Merge pull request #3240 from stan-dev/fix/constexpr-probs
Use `if constexpr` almost everywhere and cleanup type traits
2 parents b93e44d + c8c9130 commit 45bd042

396 files changed

Lines changed: 2781 additions & 2827 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

make/compiler_flags

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ CXXFLAGS_SUNDIALS ?= -pipe $(CXXFLAGS_OPTIM_SUNDIALS) $(CPPFLAGS_FLTO_SUNDIALS)
162162
# Update compiler flags with operating system specific modifications
163163
##
164164
ifeq ($(OS),Windows_NT)
165-
CXXFLAGS_WARNINGS ?= -Wall -Wno-unused-function -Wno-uninitialized -Wno-unused-but-set-variable -Wno-unused-variable -Wno-sign-compare -Wno-unused-local-typedefs -Wno-int-in-bool-context -Wno-attributes
165+
CXXFLAGS_WARNINGS ?= -Wall -Wno-template-id-cdtor -Wno-unused-function -Wno-uninitialized -Wno-unused-but-set-variable -Wno-unused-variable -Wno-sign-compare -Wno-unused-local-typedefs -Wno-int-in-bool-context -Wno-attributes
166166
CPPFLAGS_GTEST ?= -DGTEST_HAS_PTHREAD=0
167167
CPPFLAGS_OS ?= -D_USE_MATH_DEFINES
168168
CPPFLAGS_OS += -D_GLIBCXX11_USE_C99_COMPLEX

stan/math/fwd/fun/hypergeometric_1F0.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ FvarT hypergeometric_1F0(const Ta& a, const Tz& z) {
3535
partials_type_t<Ta> a_val = value_of(a);
3636
partials_type_t<Tz> z_val = value_of(z);
3737
FvarT rtn = FvarT(hypergeometric_1F0(a_val, z_val), 0.0);
38-
if (!is_constant_all<Ta>::value) {
38+
if constexpr (is_autodiff_v<Ta>) {
3939
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
4040
}
41-
if (!is_constant_all<Tz>::value) {
41+
if constexpr (is_autodiff_v<Tz>) {
4242
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val);
4343
}
4444
return rtn;

stan/math/fwd/fun/hypergeometric_2F1.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
4545

4646
typename fvar_t::Scalar grad = 0;
4747

48-
if (!is_constant<Ta1>::value) {
48+
if constexpr (is_autodiff_v<Ta1>) {
4949
grad += forward_as<fvar_t>(a1).d() * std::get<0>(grad_tuple);
5050
}
51-
if (!is_constant<Ta2>::value) {
51+
if constexpr (is_autodiff_v<Ta2>) {
5252
grad += forward_as<fvar_t>(a2).d() * std::get<1>(grad_tuple);
5353
}
54-
if (!is_constant<Tb>::value) {
54+
if constexpr (is_autodiff_v<Tb>) {
5555
grad += forward_as<fvar_t>(b).d() * std::get<2>(grad_tuple);
5656
}
57-
if (!is_constant<Tz>::value) {
57+
if constexpr (is_autodiff_v<Tz>) {
5858
grad += forward_as<fvar_t>(z).d() * std::get<3>(grad_tuple);
5959
}
6060

stan/math/fwd/fun/hypergeometric_pFq.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,11 @@ namespace math {
2525
* @param[in] z Scalar z argument
2626
* @return Generalized hypergeometric function
2727
*/
28-
template <typename Ta, typename Tb, typename Tz,
29-
typename FvarT = return_type_t<Ta, Tb, Tz>,
30-
bool grad_a = !is_constant<Ta>::value,
31-
bool grad_b = !is_constant<Tb>::value,
32-
bool grad_z = !is_constant<Tz>::value,
33-
require_all_vector_t<Ta, Tb>* = nullptr,
34-
require_fvar_t<FvarT>* = nullptr>
28+
template <
29+
typename Ta, typename Tb, typename Tz,
30+
typename FvarT = return_type_t<Ta, Tb, Tz>, bool grad_a = is_autodiff_v<Ta>,
31+
bool grad_b = is_autodiff_v<Tb>, bool grad_z = is_autodiff_v<Tz>,
32+
require_all_vector_t<Ta, Tb>* = nullptr, require_fvar_t<FvarT>* = nullptr>
3533
inline FvarT hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
3634
auto&& a_ref = to_ref(as_column_vector_or_scalar(a));
3735
auto&& b_ref = to_ref(as_column_vector_or_scalar(b));

stan/math/fwd/fun/inv_inc_beta.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
5555

5656
T_return inv_d_(0);
5757

58-
if (is_fvar<T1>::value) {
58+
if constexpr (is_fvar<T1>::value) {
5959
std::vector<T_return> da_a{a_val, a_val, one_m_b};
6060
std::vector<T_return> da_b{ap1, ap1};
6161
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
@@ -66,7 +66,7 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
6666
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
6767
}
6868

69-
if (is_fvar<T2>::value) {
69+
if constexpr (is_fvar<T2>::value) {
7070
std::vector<T_return> db_a{b_val, b_val, one_m_a};
7171
std::vector<T_return> db_b{bp1, bp1};
7272
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
@@ -79,7 +79,7 @@ inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
7979
inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3);
8080
}
8181

82-
if (is_fvar<T3>::value) {
82+
if constexpr (is_fvar<T3>::value) {
8383
inv_d_ += forward_as<fvar<T_return>>(p).d_
8484
* exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
8585
}

stan/math/fwd/fun/log_mix.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ inline void log_mix_partial_helper(
4141
= 1.0 / t_plus_one_m_t_prod_exp_lam2_m_lam1;
4242

4343
unsigned int offset = 0;
44-
if (std::is_same<T_theta, partial_return_type>::value) {
44+
if constexpr (std::is_same<T_theta, partial_return_type>::value) {
4545
partials_array[offset]
4646
= one_m_exp_lam2_m_lam1 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
4747
++offset;
4848
}
49-
if (std::is_same<T_lambda1, partial_return_type>::value) {
49+
if constexpr (std::is_same<T_lambda1, partial_return_type>::value) {
5050
partials_array[offset] = theta * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
5151
++offset;
5252
}
53-
if (std::is_same<T_lambda2, partial_return_type>::value) {
53+
if constexpr (std::is_same<T_lambda2, partial_return_type>::value) {
5454
partials_array[offset] = one_m_t_prod_exp_lam2_m_lam1
5555
* one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
5656
}

stan/math/fwd/functor/integrate_1d.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
4545
FvarT ret = finite_diff(func, args...);
4646

4747
// Calculate tangents w.r.t. integration bounds if needed
48-
if (is_fvar<T_a>::value || is_fvar<T_b>::value) {
48+
if constexpr (is_fvar<T_a>::value || is_fvar<T_b>::value) {
4949
auto val_args = std::make_tuple(value_of(args)...);
50-
if (is_fvar<T_a>::value) {
50+
if constexpr (is_fvar<T_a>::value) {
5151
ret.d_ += math::forward_as<FvarT>(a).d_
5252
* math::apply(
5353
[&](auto &&... tuple_args) {
5454
return -f(a_val, 0.0, msgs, tuple_args...);
5555
},
5656
val_args);
5757
}
58-
if (is_fvar<T_b>::value) {
58+
if constexpr (is_fvar<T_b>::value) {
5959
ret.d_ += math::forward_as<FvarT>(b).d_
6060
* math::apply(
6161
[&](auto &&... tuple_args) {

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10751075
auto ll_args_filter = internal::filter_var_scalar_types(ll_args_copy);
10761076
stan::math::for_each(
10771077
[](auto&& output_i, auto&& ll_arg_i) {
1078-
if (is_any_var_scalar_v<decltype(ll_arg_i)>) {
1078+
if constexpr (is_any_var_scalar_v<decltype(ll_arg_i)>) {
10791079
internal::collect_adjoints<true>(output_i, ll_arg_i);
10801080
}
10811081
},

stan/math/opencl/kernel_generator/broadcast.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ class broadcast_
4242
*/
4343
explicit broadcast_(T&& a) : base(std::forward<T>(a)) {
4444
const char* function = "broadcast";
45-
if (Colwise) {
45+
if constexpr (Colwise) {
4646
check_size_match(function, "Rows of ", "a", a.rows(), "", "", 1);
4747
}
48-
if (Rowwise) {
48+
if constexpr (Rowwise) {
4949
check_size_match(function, "Columns of ", "a", a.cols(), "", "", 1);
5050
}
5151
}
@@ -67,10 +67,10 @@ class broadcast_
6767
*/
6868
inline void modify_argument_indices(std::string& row_index_name,
6969
std::string& col_index_name) const {
70-
if (Colwise) {
70+
if constexpr (Colwise) {
7171
row_index_name = "0";
7272
}
73-
if (Rowwise) {
73+
if constexpr (Rowwise) {
7474
col_index_name = "0";
7575
}
7676
}

stan/math/opencl/kernel_generator/calc_if.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class calc_if_
4444
const std::string& col_index_name,
4545
const bool view_handled,
4646
const std::string& var_name_arg) const {
47-
if (Do_Calculate) {
47+
if constexpr (Do_Calculate) {
4848
var_name_ = var_name_arg;
4949
}
5050
return {};
@@ -70,7 +70,7 @@ class calc_if_
7070
std::unordered_map<const void*, const char*>& generated_all,
7171
name_generator& ng, const std::string& row_index_name,
7272
const std::string& col_index_name, const T_result& result) const {
73-
if (Do_Calculate) {
73+
if constexpr (Do_Calculate) {
7474
return this->template get_arg<0>().get_whole_kernel_parts(
7575
generated, generated_all, ng, row_index_name, col_index_name, result);
7676
} else {
@@ -92,7 +92,7 @@ class calc_if_
9292
std::unordered_map<const void*, const char*>& generated,
9393
std::unordered_map<const void*, const char*>& generated_all,
9494
cl::Kernel& kernel, int& arg_num) const {
95-
if (Do_Calculate) {
95+
if constexpr (Do_Calculate) {
9696
this->template get_arg<0>().set_args(generated, generated_all, kernel,
9797
arg_num);
9898
}

0 commit comments

Comments
 (0)