Skip to content

Commit 9e7af77

Browse files
committed
Use ref_type instead of constant check
1 parent b20dd5b commit 9e7af77

2 files changed

Lines changed: 31 additions & 28 deletions

File tree

stan/math/prim/prob/wiener5_lpdf.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -679,12 +679,12 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
679679
if (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v, T_sv>::value) {
680680
return ret_t(0.0);
681681
}
682-
using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
683-
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
684-
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
685-
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
686-
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
687-
using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
682+
using T_y_ref = ref_type_t<T_y>;
683+
using T_a_ref = ref_type_t<T_a>;
684+
using T_t0_ref = ref_type_t<T_t0>;
685+
using T_w_ref = ref_type_t<T_w>;
686+
using T_v_ref = ref_type_t<T_v>;
687+
using T_sv_ref = ref_type_t<T_sv>;
688688

689689
static constexpr const char* function_name = "wiener5_lpdf";
690690

@@ -725,12 +725,12 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
725725
return ret_t(0.0);
726726
}
727727

728-
scalar_seq_view<decltype(y_val)> y_vec(y_val);
729-
scalar_seq_view<decltype(a_val)> a_vec(a_val);
730-
scalar_seq_view<decltype(v_val)> v_vec(v_val);
731-
scalar_seq_view<decltype(w_val)> w_vec(w_val);
732-
scalar_seq_view<decltype(t0_val)> t0_vec(t0_val);
733-
scalar_seq_view<decltype(sv_val)> sv_vec(sv_val);
728+
scalar_seq_view<T_y_ref> y_vec(y_ref);
729+
scalar_seq_view<T_a_ref> a_vec(a_ref);
730+
scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
731+
scalar_seq_view<T_w_ref> w_vec(w_ref);
732+
scalar_seq_view<T_v_ref> v_vec(v_ref);
733+
scalar_seq_view<T_sv_ref> sv_vec(sv_ref);
734734
const size_t N_y_t0 = max_size(y, t0);
735735

736736
for (size_t i = 0; i < N_y_t0; ++i) {

stan/math/prim/prob/wiener_full_lpdf.hpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
327327
return ret_t(0);
328328
}
329329

330-
using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
331-
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
332-
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
333-
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
334-
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
335-
using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
336-
using T_sw_ref = ref_type_if_t<!is_constant<T_sw>::value, T_sw>;
337-
using T_st0_ref = ref_type_if_t<!is_constant<T_st0>::value, T_st0>;
330+
using T_y_ref = ref_type_t<T_y>;
331+
using T_a_ref = ref_type_t<T_a>;
332+
using T_v_ref = ref_type_t<T_v>;
333+
using T_w_ref = ref_type_t<T_w>;
334+
using T_t0_ref = ref_type_t<T_t0>;
335+
using T_sv_ref = ref_type_t<T_sv>;
336+
using T_sw_ref = ref_type_t<T_sw>;
337+
using T_st0_ref = ref_type_t<T_st0>;
338338

339339
using T_partials_return
340340
= partials_return_t<T_y, T_a, T_t0, T_w, T_v, T_sv, T_sw, T_st0>;
@@ -385,14 +385,14 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
385385
if (N == 0) {
386386
return ret_t(0);
387387
}
388-
scalar_seq_view<decltype(y_val)> y_vec(y_val);
389-
scalar_seq_view<decltype(a_val)> a_vec(a_val);
390-
scalar_seq_view<decltype(v_val)> v_vec(v_val);
391-
scalar_seq_view<decltype(w_val)> w_vec(w_val);
392-
scalar_seq_view<decltype(t0_val)> t0_vec(t0_val);
393-
scalar_seq_view<decltype(sv_val)> sv_vec(sv_val);
394-
scalar_seq_view<decltype(sw_val)> sw_vec(sw_val);
395-
scalar_seq_view<decltype(st0_val)> st0_vec(st0_val);
388+
scalar_seq_view<T_y_ref> y_vec(y_ref);
389+
scalar_seq_view<T_a_ref> a_vec(a_ref);
390+
scalar_seq_view<T_v_ref> v_vec(v_ref);
391+
scalar_seq_view<T_w_ref> w_vec(w_ref);
392+
scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
393+
scalar_seq_view<T_sv_ref> sv_vec(sv_ref);
394+
scalar_seq_view<T_sw_ref> sw_vec(sw_ref);
395+
scalar_seq_view<T_st0_ref> st0_vec(st0_ref);
396396
const size_t N_y_t0 = max_size(y, t0, st0);
397397

398398
for (size_t i = 0; i < N_y_t0; ++i) {
@@ -449,6 +449,9 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
449449
// calculate density and partials
450450
for (size_t i = 0; i < N; i++) {
451451
if (sw_vec[i] == 0 && st0_vec[i] == 0) {
452+
// note: because we're delegating to wiener5_lpdf,
453+
// we need to make sure is_constant is consistent between
454+
// our inputs and these
452455
result += wiener_lpdf<propto>(y_vec[i], a_vec[i], t0_vec[i], w_vec[i],
453456
v_vec[i], sv_vec[i], precision_derivatives);
454457
continue;

0 commit comments

Comments
 (0)