Skip to content

Commit 67021e0

Browse files
committed
update for cpplint
1 parent de50d88 commit 67021e0

11 files changed

Lines changed: 58 additions & 74 deletions

stan/math/opencl/kernel_generator/broadcast.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ class broadcast_
4242
*/
4343
explicit broadcast_(T&& a) : base(std::forward<T>(a)) {
4444
const char* function = "broadcast";
45-
if (Colwise) {
45+
if constexpr (Colwise) {
4646
check_size_match(function, "Rows of ", "a", a.rows(), "", "", 1);
4747
}
48-
if (Rowwise) {
48+
if constexpr (Rowwise) {
4949
check_size_match(function, "Columns of ", "a", a.cols(), "", "", 1);
5050
}
5151
}
@@ -67,10 +67,10 @@ class broadcast_
6767
*/
6868
inline void modify_argument_indices(std::string& row_index_name,
6969
std::string& col_index_name) const {
70-
if (Colwise) {
70+
if constexpr (Colwise) {
7171
row_index_name = "0";
7272
}
73-
if (Rowwise) {
73+
if constexpr (Rowwise) {
7474
col_index_name = "0";
7575
}
7676
}

stan/math/opencl/kernel_generator/calc_if.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class calc_if_
4444
const std::string& col_index_name,
4545
const bool view_handled,
4646
const std::string& var_name_arg) const {
47-
if (Do_Calculate) {
47+
if constexpr (Do_Calculate) {
4848
var_name_ = var_name_arg;
4949
}
5050
return {};
@@ -70,7 +70,7 @@ class calc_if_
7070
std::unordered_map<const void*, const char*>& generated_all,
7171
name_generator& ng, const std::string& row_index_name,
7272
const std::string& col_index_name, const T_result& result) const {
73-
if (Do_Calculate) {
73+
if constexpr (Do_Calculate) {
7474
return this->template get_arg<0>().get_whole_kernel_parts(
7575
generated, generated_all, ng, row_index_name, col_index_name, result);
7676
} else {
@@ -92,7 +92,7 @@ class calc_if_
9292
std::unordered_map<const void*, const char*>& generated,
9393
std::unordered_map<const void*, const char*>& generated_all,
9494
cl::Kernel& kernel, int& arg_num) const {
95-
if (Do_Calculate) {
95+
if constexpr (Do_Calculate) {
9696
this->template get_arg<0>().set_args(generated, generated_all, kernel,
9797
arg_num);
9898
}

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,12 @@ struct multi_result_kernel_internal {
7676

7777
int expression_rows = expression.rows();
7878
int expression_cols = expression.cols();
79-
if constexpr (is_colwise_reduction<T_current_expression>::value
80-
&& expression_cols == -1) {
79+
if (is_colwise_reduction<T_current_expression>::value && expression_cols == -1) {
8180
expression_cols = n_cols;
8281
expression_rows = expression.thread_rows();
8382
expression_rows = internal::colwise_reduction_wgs_rows(
8483
expression_rows < 0 ? n_rows : expression_rows, expression_cols);
85-
} else if constexpr (is_reduction_2d<T_current_expression>::value
86-
&& expression_cols == -1) {
84+
} else if (is_reduction_2d<T_current_expression>::value && expression_cols == -1) {
8785
expression_rows = internal::colwise_reduction_wgs_rows(n_rows, n_cols);
8886
if (expression_rows == 0) {
8987
expression_cols = 0;

stan/math/opencl/kernel_generator/optional_broadcast.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ class optional_broadcast_
6666
kernel_parts res;
6767
res.body
6868
+= type_str<Scalar>() + " " + var_name_ + " = " + var_name_arg + ";\n";
69-
if (Colwise) {
69+
if constexpr (Colwise) {
7070
res.args += "int " + var_name_ + "is_multirow, ";
7171
}
72-
if (Rowwise) {
72+
if constexpr (Rowwise) {
7373
res.args += "int " + var_name_ + "is_multicol, ";
7474
}
7575
return res;
@@ -82,10 +82,10 @@ class optional_broadcast_
8282
*/
8383
inline void modify_argument_indices(std::string& row_idx_name,
8484
std::string& col_idx_name) const {
85-
if (Colwise) {
85+
if constexpr (Colwise) {
8686
row_idx_name = "(" + row_idx_name + " * " + var_name_ + "is_multirow)";
8787
}
88-
if (Rowwise) {
88+
if constexpr (Rowwise) {
8989
col_idx_name = "(" + col_idx_name + " * " + var_name_ + "is_multicol)";
9090
}
9191
}
@@ -109,11 +109,11 @@ class optional_broadcast_
109109
std::unordered_map<const void*, const char*> generated2;
110110
this->template get_arg<0>().set_args(generated2, generated_all, kernel,
111111
arg_num);
112-
if (Colwise) {
112+
if constexpr (Colwise) {
113113
kernel.setArg(arg_num++, static_cast<int>(
114114
this->template get_arg<0>().rows() != 1));
115115
}
116-
if (Rowwise) {
116+
if constexpr (Rowwise) {
117117
kernel.setArg(arg_num++, static_cast<int>(
118118
this->template get_arg<0>().cols() != 1));
119119
}

stan/math/opencl/prim/binomial_lpmf.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ return_type_t<T_prob_cl> binomial_lpmf(const T_n_cl& n, const T_N_cl N,
9595
matrix_cl<double> deriv_cl;
9696

9797
constexpr bool need_sums
98-
= is_autodiff_v<T_prob_cl> && is_stan_scalar<T_prob_cl>;
98+
= is_autodiff_v<T_prob_cl> && is_stan_scalar_v<T_prob_cl>;
9999
constexpr bool need_deriv
100-
= is_autodiff_v<T_prob_cl> && !is_stan_scalar<T_prob_cl>;
100+
= is_autodiff_v<T_prob_cl> && !is_stan_scalar_v<T_prob_cl>;
101101

102102
results(check_n_bounded, check_N_nonnegative, check_theta_bounded, logp_cl,
103103
sum_n_cl, sum_N_cl, deriv_cl)
@@ -110,7 +110,7 @@ return_type_t<T_prob_cl> binomial_lpmf(const T_n_cl& n, const T_N_cl N,
110110
auto ops_partials = make_partials_propagator(theta_col);
111111

112112
if constexpr (is_autodiff_v<T_prob_cl>) {
113-
if (need_sums) {
113+
if constexpr (need_sums) {
114114
int sum_n = sum(from_matrix_cl(sum_n_cl));
115115
int sum_N = sum(from_matrix_cl(sum_N_cl));
116116
double theta_dbl = forward_as<double>(theta_val);

stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,18 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
145145
partials<1>(ops_partials) = rowwise_sum(alpha_derivative_cl);
146146
}
147147
}
148-
if constexpr (is_autodiff_v<T_beta> && N_attributes != 0) {
149-
partials<2>(ops_partials) = transpose(x_val) * neg_softmax_lin_cl;
150-
matrix_cl<double> temp(N_classes, local_size * N_attributes);
151-
try {
152-
opencl_kernels::categorical_logit_glm_beta_derivative(
153-
cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
154-
forward_as<arena_matrix_cl<double>>(partials<2>(ops_partials)), temp,
155-
y_val_cl, x_val, N_instances, N_attributes, N_classes, is_y_vector);
156-
} catch (const cl::Error& e) {
157-
check_opencl_error(function, e);
148+
if constexpr (is_autodiff_v<T_beta>) {
149+
if (N_attributes != 0) {
150+
partials<2>(ops_partials) = transpose(x_val) * neg_softmax_lin_cl;
151+
matrix_cl<double> temp(N_classes, local_size * N_attributes);
152+
try {
153+
opencl_kernels::categorical_logit_glm_beta_derivative(
154+
cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
155+
forward_as<arena_matrix_cl<double>>(partials<2>(ops_partials)), temp,
156+
y_val_cl, x_val, N_instances, N_attributes, N_classes, is_y_vector);
157+
} catch (const cl::Error& e) {
158+
check_opencl_error(function, e);
159+
}
158160
}
159161
}
160162
return ops_partials.build(logp);

stan/math/prim/meta/is_constant.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ struct is_constant : bool_constant<std::is_convertible<T, double>::value> {};
3838
template <typename... T>
3939
using is_constant_all = math::conjunction<is_constant<T>...>;
4040

41+
template <typename... T>
42+
inline constexpr bool is_constant_all_v = is_constant_all<std::decay_t<T>...>::value;
43+
4144
/** \ingroup type_trait
4245
* Defines a static member named value and sets it to true
4346
* if the type of the elements in the provided std::vector

stan/math/prim/prob/beta_lpdf.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ return_type_t<T_y, T_scale_succ, T_scale_fail> beta_lpdf(
105105
logp += sum(lgamma(alpha_beta)) * N / max_size(alpha, beta);
106106
if constexpr (is_any_autodiff_v<T_scale_succ, T_scale_fail>) {
107107
const auto& digamma_alpha_beta
108-
= to_ref_if < is_autodiff_v<
109-
T_scale_succ> && is_autodiff_v<T_scale_fail> > (digamma(alpha_beta));
108+
= to_ref_if<is_all_autodiff_v<T_scale_succ, T_scale_fail>>(
109+
digamma(alpha_beta));
110110
if constexpr (is_autodiff_v<T_scale_succ>) {
111111
edge<1>(ops_partials).partials_
112112
= log_y + digamma_alpha_beta - digamma(alpha_val);

stan/math/prim/prob/exp_mod_normal_lpdf.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ return_type_t<T_y, T_loc, T_scale, T_inv_scale> exp_mod_normal_lpdf(
9595
T_loc> + is_autodiff_v<T_scale> + is_autodiff_v<T_inv_scale> >= 2>(
9696
-SQRT_TWO_OVER_SQRT_PI * exp_m_sq_inner_term / erfc_calc);
9797
if constexpr (is_any_autodiff_v<T_y, T_loc>) {
98-
const auto& deriv
99-
= to_ref_if < is_autodiff_v<
100-
T_y> && is_autodiff_v<T_loc> > (lambda_val + deriv_logerfc * inv_sigma);
98+
const auto& deriv = to_ref_if<is_all_autodiff_v<T_y, T_loc>>(
99+
lambda_val + deriv_logerfc * inv_sigma);
101100
if constexpr (is_autodiff_v<T_y>) {
102101
partials<0>(ops_partials) = -deriv;
103102
}

stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
4141
return 0;
4242
}
4343

44-
if constexpr (is_autodiff_v<Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
44+
if constexpr (is_all_autodiff_v<Ta, Tb, Td>) {
4545
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
4646
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
4747
arena_t<promote_scalar_t<var, Td>> arena_D = D;
@@ -62,8 +62,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
6262
});
6363

6464
return res;
65-
} else if constexpr (is_autodiff_v<
66-
Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
65+
} else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
6766
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
6867
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
6968
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
@@ -80,8 +79,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
8079
});
8180

8281
return res;
83-
} else if constexpr (is_autodiff_v<
84-
Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
82+
} else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
8583
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
8684
const auto& B_ref = to_ref(B);
8785
arena_t<promote_scalar_t<var, Td>> arena_D = D;
@@ -100,8 +98,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
10098
});
10199

102100
return res;
103-
} else if constexpr (is_autodiff_v<
104-
Ta> && is_constant_v<Tb> && is_constant_v<Td>) {
101+
} else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
105102
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
106103
const auto& B_ref = to_ref(B);
107104
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
@@ -117,8 +114,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
117114
});
118115

119116
return res;
120-
} else if constexpr (is_constant_v<
121-
Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
117+
} else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
122118
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
123119
arena_t<promote_scalar_t<var, Td>> arena_D = D;
124120
auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
@@ -136,8 +132,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
136132
});
137133

138134
return res;
139-
} else if constexpr (is_constant_v<
140-
Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
135+
} else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
141136
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
142137
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
143138
auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
@@ -149,8 +144,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
149144
});
150145

151146
return res;
152-
} else if constexpr (is_constant_v<
153-
Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
147+
} else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
154148
const auto& B_ref = to_ref(B);
155149
arena_t<promote_scalar_t<var, Td>> arena_D = D;
156150
auto BTAsolveB = to_arena(value_of(B_ref).transpose()
@@ -196,7 +190,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
196190
return 0;
197191
}
198192

199-
if constexpr (is_autodiff_v<Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
193+
if constexpr (is_all_autodiff_v<Ta, Tb, Td>) {
200194
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
201195
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
202196
arena_t<promote_scalar_t<var, Td>> arena_D = D;
@@ -216,8 +210,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
216210
});
217211

218212
return res;
219-
} else if constexpr (is_autodiff_v<
220-
Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
213+
} else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
221214
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
222215
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
223216
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
@@ -235,8 +228,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
235228
});
236229

237230
return res;
238-
} else if constexpr (is_autodiff_v<
239-
Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
231+
} else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
240232
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
241233
const auto& B_ref = to_ref(B);
242234
arena_t<promote_scalar_t<var, Td>> arena_D = D;
@@ -255,8 +247,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
255247
});
256248

257249
return res;
258-
} else if constexpr (is_autodiff_v<
259-
Ta> && is_constant_v<Tb> && is_constant_v<Td>) {
250+
} else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
260251
arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
261252
const auto& B_ref = to_ref(B);
262253
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
@@ -273,8 +264,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
273264
});
274265

275266
return res;
276-
} else if constexpr (is_constant_v<
277-
Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
267+
} else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
278268
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
279269
arena_t<promote_scalar_t<var, Td>> arena_D = D;
280270
auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
@@ -291,8 +281,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
291281
});
292282

293283
return res;
294-
} else if constexpr (is_constant_v<
295-
Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
284+
} else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
296285
arena_t<promote_scalar_t<var, Tb>> arena_B = B;
297286
arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
298287
auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
@@ -305,8 +294,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
305294
});
306295

307296
return res;
308-
} else if constexpr (is_constant_v<
309-
Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
297+
} else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
310298
const auto& B_ref = to_ref(B);
311299
arena_t<promote_scalar_t<var, Td>> arena_D = D;
312300
auto BTAsolveB = to_arena(value_of(B_ref).transpose()

0 commit comments

Comments
 (0)