@@ -61,13 +61,10 @@ def get_cpp_type(stan_type):
6161
6262
6363simplex = "simplex"
64+ stochastic = "stochastic_matrix"
6465pos_definite = "positive_definite_matrix"
6566scalar_return_type = "scalar_return_type"
6667
67- make_special_arg_values = {
68- simplex : "make_simplex" ,
69- pos_definite : "make_pos_definite_matrix" ,
70- }
7168
7269# list of function arguments that need special scalar values.
7370# None means to use the default argument value.
@@ -79,6 +76,9 @@ def get_cpp_type(stan_type):
7976 "log1m_exp" : [- 0.6 ],
8077 "categorical_rng" : [simplex , None ],
8178 "categorical_lpmf" : [None , simplex ],
79+ "corr_matrix_constrain" : [None , 2 ],
80+ "corr_matrix_free" : [1 ],
81+ "cov_matrix_constrain" : [None , 1 ],
8282 "cholesky_decompose" : [pos_definite , None ],
8383 "cholesky_corr_constrain" : [None , 2 ],
8484 "cholesky_factor_constrain" : [None , 1 ,1 ],
@@ -93,6 +93,8 @@ def get_cpp_type(stan_type):
9393 "lkj_corr_lpdf" : [1 , None ],
9494 "log_diff_exp" : [3 , None ],
9595 "log_inv_logit_diff" : [1.2 , 0.4 ],
96+ "lb_constrain" : [None , 0.1 ],
97+ "lb_free" : [0.5 , 0.1 ],
9698 "lub_constrain" : [None , 0.1 , 0.9 ],
9799 "lub_free" : [0.5 , 0.1 , 0.9 ],
98100 "mdivide_left_spd" : [pos_definite , None ],
@@ -123,12 +125,17 @@ def get_cpp_type(stan_type):
123125 "offset_multiplier_free" : [10 , None , None ],
124126 "simplex_constrain" : [None , scalar_return_type ],
125127 "simplex_free" : [simplex ],
128+ "sum_to_zero_free" : [0 ],
126129 "std_normal_log_qf" : [- 0.1 ],
130+ "stochastic_column_free" : [stochastic ],
131+ "stochastic_row_free" : [stochastic ],
127132 "student_t_cdf" : [0.8 , None , 0.4 , None ],
128133 "student_t_cdf_log" : [0.8 , None , 0.4 , None ],
129134 "student_t_ccdf_log" : [0.8 , None , 0.4 , None ],
130135 "student_t_lccdf" : [0.8 , None , 0.4 , None ],
131136 "student_t_lcdf" : [0.8 , None , 0.4 , None ],
137+ "ub_constrain" : [None , 0.9 ],
138+ "ub_free" : [0.5 , 0.9 ],
132139 "unit_vector_constrain" : [None , scalar_return_type ],
133140 "unit_vector_free" : [simplex ],
134141 "uniform_cdf" : [None , 0.2 , 0.9 ],
@@ -322,6 +329,7 @@ def handle_function_list(functions_input):
322329 function_names = []
323330 function_signatures = []
324331 for f in functions_input :
332+ f = f .strip ()
325333 if ("." in f ) or ("/" in f ) or ("\\ " in f ):
326334 with open (f ) as fh :
327335 functions_input .extend (parse_signature_file (fh ))
0 commit comments