Skip to content

Commit a3253a1

Browse files
committed
Expression testing tweaks
1 parent 5d9cacd commit a3253a1

3 files changed

Lines changed: 27 additions & 18 deletions

File tree

stan/math/prim/constraint/stochastic_column_free.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ namespace math {
1717
* @tparam Mat type of the Matrix
1818
* @param y Columnwise stochastic matrix input of dimensionality (N, K)
1919
*/
20-
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
21-
require_not_st_var<Mat>* = nullptr>
20+
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr>
2221
inline plain_type_t<Mat> stochastic_column_free(const Mat& y) {
2322
auto&& y_ref = to_ref(y);
2423
const Eigen::Index M = y_ref.cols();

stan/math/prim/constraint/stochastic_row_free.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ namespace math {
1616
* @tparam Mat type of the Matrix
1717
* @param y Rowwise simplex Matrix input of dimensionality (N, K)
1818
*/
19-
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
20-
require_not_st_var<Mat>* = nullptr>
19+
template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr>
2120
inline plain_type_t<Mat> stochastic_row_free(const Mat& y) {
2221
auto&& y_ref = to_ref(y);
2322
const Eigen::Index N = y_ref.rows();

test/sig_utils.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,23 +139,26 @@ def get_cpp_type(stan_type):
139139
# list of functions we do not test. These are mainly functions implemented in compiler
140140
# (not in Stan Math).
141141
ignored = [
142-
"lchoose", # synonym for binomial_coefficient_log
143-
"lmultiply", # synonym for multiply_log
144142
"std_normal_qf", # synonym for inv_Phi
145143
"if_else",
146-
# these are all slight renames compared to stan math
147-
"cholesky_factor_corr_jacobian",
148-
"cholesky_factor_cov_jacobian",
149-
"cholesky_factor_corr_constrain",
150-
"cholesky_factor_cov_constrain",
151-
"lower_bound_jacobian",
152-
"upper_bound_jacobian",
153-
"lower_upper_bound_jacobian",
154-
"lower_bound_constrain",
155-
"upper_bound_constrain",
156-
"lower_upper_bound_constrain",
157144
]
158145

146+
# these are all slight renames compared to stan math
147+
renames = {
148+
"lchoose": "binomial_coefficient_log",
149+
"lmultiply": "multiply_log",
150+
"cholesky_factor_corr_constrain": "cholesky_corr_constrain",
151+
"cholesky_factor_corr_unconstrain": "cholesky_corr_free",
152+
"cholesky_factor_cov_constrain": "cholesky_factor_constrain",
153+
"cholesky_factor_cov_unconstrain": "cholesky_factor_free",
154+
"lower_bound_constrain": "lb_constrain",
155+
"lower_bound_unconstrain": "lb_free",
156+
"upper_bound_constrain": "ub_constrain",
157+
"upper_bound_unconstrain": "ub_free",
158+
"lower_upper_bound_constrain": "lub_constrain",
159+
"lower_upper_bound_unconstrain": "lub_free",
160+
}
161+
159162
# list of function argument indices, for which real valued arguments are not differentiable
160163
# - they need to be double even in autodiff overloads
161164
non_differentiable_args = {
@@ -276,6 +279,14 @@ def get_signatures():
276279

277280
return res + internal_signatures
278281

282+
def handle_rename(function_name):
283+
"""
284+
Replace certain function names with their stan math counterparts
285+
"""
286+
fname = renames.get(function_name, function_name)
287+
if fname.endswith("_unconstrain"):
288+
fname = fname.replace("_unconstrain", "_free")
289+
return fname
279290

280291
def parse_signature(signature):
281292
"""
@@ -292,7 +303,7 @@ def parse_signature(signature):
292303
for i in args
293304
if i.strip()
294305
]
295-
return return_type.strip(), function_name.strip(), args
306+
return return_type.strip(), handle_rename(function_name.strip()), args
296307

297308

298309
def handle_function_list(functions_input):

0 commit comments

Comments
 (0)