Skip to content

Commit b592e63

Browse files
committed
remove all uses for forward_as
1 parent 734df39 commit b592e63

1 file changed

Lines changed: 66 additions & 68 deletions

File tree

test/unit/math/test_ad_matvar.hpp

Lines changed: 66 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -430,79 +430,77 @@ void expect_ad_matvar_impl(const ad_tolerances& tols, const F& f,
430430
if (!stan::math::disjunction<is_var_matrix<Types>...>::value) {
431431
FAIL() << "expect_ad_matvar requires at least one varmat input!"
432432
<< std::endl;
433-
}
433+
} else {
434434

435-
if (!stan::math::disjunction<is_var<scalar_type_t<Types>>...>::value) {
435+
} if constexpr (!stan::math::disjunction<is_var<scalar_type_t<Types>>...>::value) {
436436
FAIL() << "expect_ad_matvar requires at least one autodiff input!"
437437
<< std::endl;
438-
}
439-
440-
auto A_mv_tuple = std::make_tuple(make_matvar_compatible<Types>(x)...);
441-
auto A_vm_tuple = std::make_tuple(make_varmat_compatible<Types>(x)...);
442-
443-
bool any_varmat = stan::math::apply(
444-
[](const auto&... args) {
445-
return stan::math::disjunction<is_var_matrix<decltype(args)>...>::value
446-
|| stan::math::disjunction<stan::math::conjunction<
447-
is_std_vector<decltype(args)>,
448-
is_var_matrix<value_type_t<decltype(args)>>>...>::value;
449-
},
450-
A_vm_tuple);
451-
452-
if (!any_varmat) {
453-
SUCCEED(); // If no varmats are created, skip this test
454-
return;
455-
}
456-
457-
using T_mv_ret = plain_type_t<decltype(stan::math::apply(f, A_mv_tuple))>;
458-
using T_vm_ret = plain_type_t<decltype(stan::math::apply(f, A_vm_tuple))>;
459-
460-
T_mv_ret A_mv_ret;
461-
T_vm_ret A_vm_ret;
462-
463-
if (is_var_matrix<T_mv_ret>::value
464-
|| (is_std_vector<T_mv_ret>::value
465-
&& is_var_matrix<value_type_t<T_mv_ret>>::value)) {
466-
FAIL() << "A function with matvar inputs should not return "
467-
<< type_name<T_mv_ret>() << std::endl;
468-
}
469-
470-
if (is_eigen<T_vm_ret>::value
471-
|| (is_std_vector<T_vm_ret>::value
472-
&& is_eigen<value_type_t<T_vm_ret>>::value)) {
473-
FAIL() << "A function with varmat inputs should not return "
474-
<< type_name<T_vm_ret>() << std::endl;
475-
}
476-
477-
// If one throws, the other should throw as well
478-
try {
479-
A_mv_ret = stan::math::apply(f, A_mv_tuple);
480-
} catch (...) {
481-
try {
482-
stan::math::apply(f, A_vm_tuple);
483-
FAIL() << "`Eigen::Matrix<var, R, C>` function throws and "
484-
"`var_value<Eigen::Matrix<double, R, C>>` does not";
485-
} catch (...) {
486-
SUCCEED();
487-
return;
488-
}
489-
}
490-
try {
491-
A_vm_ret = stan::math::apply(f, A_vm_tuple);
492-
} catch (...) {
493-
try {
494-
stan::math::apply(f, A_mv_tuple);
495-
FAIL() << "`var_value<Eigen::Matrix<double, R, C>>` function throws and "
496-
"`Eigen::Matrix<var, R, C>` does not";
497-
} catch (...) {
498-
SUCCEED();
438+
} else {
439+
auto A_mv_tuple = std::make_tuple(make_matvar_compatible<Types>(x)...);
440+
auto A_vm_tuple = std::make_tuple(make_varmat_compatible<Types>(x)...);
441+
442+
constexpr bool any_varmat = stan::math::apply(
443+
[](const auto&... args) {
444+
return stan::math::disjunction<is_var_matrix<decltype(args)>...>::value
445+
|| stan::math::disjunction<stan::math::conjunction<
446+
is_std_vector<decltype(args)>,
447+
is_var_matrix<value_type_t<decltype(args)>>>...>::value;
448+
},
449+
A_vm_tuple);
450+
451+
if constexpr (!any_varmat) {
452+
SUCCEED(); // If no varmats are created, skip this test
499453
return;
454+
} else {
455+
using T_mv_ret = plain_type_t<decltype(stan::math::apply(f, A_mv_tuple))>;
456+
using T_vm_ret = plain_type_t<decltype(stan::math::apply(f, A_vm_tuple))>;
457+
458+
T_mv_ret A_mv_ret;
459+
T_vm_ret A_vm_ret;
460+
461+
if (is_var_matrix<T_mv_ret>::value
462+
|| (is_std_vector<T_mv_ret>::value
463+
&& is_var_matrix<value_type_t<T_mv_ret>>::value)) {
464+
FAIL() << "A function with matvar inputs should not return "
465+
<< type_name<T_mv_ret>() << std::endl;
466+
}
467+
468+
if (is_eigen<T_vm_ret>::value
469+
|| (is_std_vector<T_vm_ret>::value
470+
&& is_eigen<value_type_t<T_vm_ret>>::value)) {
471+
FAIL() << "A function with varmat inputs should not return "
472+
<< type_name<T_vm_ret>() << std::endl;
473+
}
474+
475+
// If one throws, the other should throw as well
476+
try {
477+
A_mv_ret = stan::math::apply(f, A_mv_tuple);
478+
} catch (...) {
479+
try {
480+
stan::math::apply(f, A_vm_tuple);
481+
FAIL() << "`Eigen::Matrix<var, R, C>` function throws and "
482+
"`var_value<Eigen::Matrix<double, R, C>>` does not";
483+
} catch (...) {
484+
SUCCEED();
485+
return;
486+
}
487+
}
488+
try {
489+
A_vm_ret = stan::math::apply(f, A_vm_tuple);
490+
} catch (...) {
491+
try {
492+
stan::math::apply(f, A_mv_tuple);
493+
FAIL() << "`var_value<Eigen::Matrix<double, R, C>>` function throws and "
494+
"`Eigen::Matrix<var, R, C>` does not";
495+
} catch (...) {
496+
SUCCEED();
497+
return;
498+
}
499+
}
500+
test_matvar_gradient(tols, A_mv_ret, A_vm_ret, A_mv_tuple, A_vm_tuple);
501+
stan::math::recover_memory();
500502
}
501503
}
502-
503-
test_matvar_gradient(tols, A_mv_ret, A_vm_ret, A_mv_tuple, A_vm_tuple);
504-
505-
stan::math::recover_memory();
506504
}
507505

508506
/** @name expect_ad_matvar
@@ -550,7 +548,7 @@ void expect_ad_matvar(const ad_tolerances& tols, const F& f, const EigMat& x) {
550548
*/
551549
template <typename F, typename EigMat>
552550
void expect_ad_matvar(const F& f, const EigMat& x) {
553-
ad_tolerances tols;
551+
constexpr ad_tolerances tols;
554552

555553
expect_ad_matvar(tols, f, x);
556554
}

0 commit comments

Comments
 (0)