Skip to content

Commit 26829c1

Browse files
committed
More expression test set-up
1 parent 40707de commit 26829c1

3 files changed

Lines changed: 22 additions & 4 deletions

File tree

test/code_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ def build_arguments(self, signature_parser, arg_overloads, size):
129129
)
130130
elif inner_type == "(real, vector, ostream_ptr, vector) => vector":
131131
arg = statement_types.OdeFunctorVariable("functor" + suffix)
132+
elif inner_type == "stochastic_matrix":
133+
arg = statement_types.StochasticMatrixVariable(
134+
overload, "stochastic_matrix" + suffix, size, value
135+
)
132136
else:
133137
raise Exception("Inner type " + inner_type + " not supported")
134138

test/sig_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,10 @@ def get_cpp_type(stan_type):
6161

6262

6363
simplex = "simplex"
64+
stochastic = "stochastic_matrix"
6465
pos_definite = "positive_definite_matrix"
6566
scalar_return_type = "scalar_return_type"
6667

67-
make_special_arg_values = {
68-
simplex: "make_simplex",
69-
pos_definite: "make_pos_definite_matrix",
70-
}
7168

7269
# list of function arguments that need special scalar values.
7370
# None means to use the default argument value.
@@ -79,6 +76,9 @@ def get_cpp_type(stan_type):
7976
"log1m_exp": [-0.6],
8077
"categorical_rng": [simplex, None],
8178
"categorical_lpmf": [None, simplex],
79+
"corr_matrix_constrain": [None, 2],
80+
"corr_matrix_free": [1],
81+
"cov_matrix_constrain": [None, 1],
8282
"cholesky_decompose": [pos_definite, None],
8383
"cholesky_corr_constrain": [None, 2],
8484
"cholesky_factor_constrain": [None, 1,1],
@@ -93,6 +93,8 @@ def get_cpp_type(stan_type):
9393
"lkj_corr_lpdf": [1, None],
9494
"log_diff_exp": [3, None],
9595
"log_inv_logit_diff": [1.2, 0.4],
96+
"lb_constrain": [None, 0.1],
97+
"lb_free": [0.5, 0.1],
9698
"lub_constrain": [None, 0.1, 0.9],
9799
"lub_free": [0.5, 0.1, 0.9],
98100
"mdivide_left_spd": [pos_definite, None],
@@ -123,12 +125,17 @@ def get_cpp_type(stan_type):
123125
"offset_multiplier_free": [10, None, None],
124126
"simplex_constrain": [None, scalar_return_type],
125127
"simplex_free": [simplex],
128+
"sum_to_zero_free": [0],
126129
"std_normal_log_qf": [-0.1],
130+
"stochastic_column_free": [stochastic],
131+
"stochastic_row_free": [stochastic],
127132
"student_t_cdf": [0.8, None, 0.4, None],
128133
"student_t_cdf_log": [0.8, None, 0.4, None],
129134
"student_t_ccdf_log": [0.8, None, 0.4, None],
130135
"student_t_lccdf": [0.8, None, 0.4, None],
131136
"student_t_lcdf": [0.8, None, 0.4, None],
137+
"ub_constrain": [None, 0.9],
138+
"ub_free": [0.5, 0.9],
132139
"unit_vector_constrain": [None, scalar_return_type],
133140
"unit_vector_free": [simplex],
134141
"uniform_cdf": [None, 0.2, 0.9],
@@ -322,6 +329,7 @@ def handle_function_list(functions_input):
322329
function_names = []
323330
function_signatures = []
324331
for f in functions_input:
332+
f = f.strip()
325333
if ("." in f) or ("/" in f) or ("\\" in f):
326334
with open(f) as fh:
327335
functions_input.extend(parse_signature_file(fh))

test/statement_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numbers
22
import os
3+
from re import S
34

45
from sig_utils import overload_scalar, get_cpp_type, arg_types
56

@@ -209,6 +210,11 @@ def cpp(self):
209210
)
210211
)
211212

213+
class StochasticMatrixVariable(SimplexVariable):
214+
# works for size one
215+
def __init__(self, overload, name, size, value=None):
216+
super().__init__(overload, name, size, value)
217+
self.stan_arg = "matrix"
212218

213219
class PositiveDefiniteMatrixVariable(CppStatement):
214220
"""Represents a positive definite matrix variable"""

0 commit comments

Comments
 (0)