22#define STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_HPP
33#include < stan/math/prim/fun/Eigen.hpp>
44#include < stan/math/mix/functor/laplace_likelihood.hpp>
5+ #include < test/unit/pretty_print_types.hpp>
56#include < stan/math/rev/meta.hpp>
67#include < stan/math/rev/core.hpp>
78#include < stan/math/rev/fun.hpp>
89#include < stan/math/rev/fun/value_of.hpp>
910#include < stan/math/rev/functor.hpp>
1011#include < stan/math/prim/fun/to_ref.hpp>
1112#include < stan/math/prim/fun/quad_form_diag.hpp>
12- #include < stan/math/prim/functor/iter_tuple_n .hpp>
13-
13+ #include < stan/math/prim/functor/iter_tuple_nested .hpp>
14+ # include < unsupported/Eigen/MatrixFunctions >
1415#include < cmath>
1516
1617/* *
@@ -43,7 +44,7 @@ struct laplace_options {
4344 */
4445 double tolerance{1e-6 };
4546 /* Maximum number of steps*/
46- int64_t max_num_steps{100 };
47+ int max_num_steps{100 };
4748};
4849
4950namespace internal {
@@ -297,7 +298,7 @@ inline void set_zero_adjoint(Output&& output) {
297298 if constexpr (is_all_arithmetic_scalar_v<Output>) {
298299 return ;
299300 } else {
300- return iter_tuple_n (
301+ return iter_tuple_nested (
301302 [](auto && output_i) {
302303 using output_i_t = std::decay_t <decltype (output_i)>;
303304 if constexpr (is_all_arithmetic_scalar_v<output_i_t >) {
@@ -312,9 +313,9 @@ inline void set_zero_adjoint(Output&& output) {
312313 output_i.adj () = 0 ;
313314 } else {
314315 static_assert (
315- 1 ,
316+ sizeof (Output*) == 0 ,
316317 " INTERNAL ERROR:(laplace_marginal_lpdf) set_zero_adjoints was "
317- " not able to deduce the actiopns needed for the given type." );
318+ " not able to deduce the actions needed for the given type." );
318319 }
319320 },
320321 std::forward<Output>(output));
@@ -333,7 +334,7 @@ template <bool ZeroInput = false, typename Output, typename Input,
333334 require_t <is_all_arithmetic_scalar<Output>>* = nullptr ,
334335 require_t <is_all_var_scalar<Input>>* = nullptr >
335336inline void collect_adjoints (Output& output, Input&& input) {
336- return iter_tuple_n (
337+ return iter_tuple_nested (
337338 [](auto && output_i, auto && input_i) {
338339 using output_i_t = std::decay_t <decltype (output_i)>;
339340 if constexpr (is_std_vector_v<output_i_t >) {
@@ -357,7 +358,7 @@ inline void collect_adjoints(Output& output, Input&& input) {
357358 }
358359 } else {
359360 static_assert (
360- 1 ,
361+ sizeof (Output*) == 0 ,
361362 " INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not "
362363 " able to deduce the actiopns needed for the given type." );
363364 }
@@ -557,7 +558,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
557558 std::move (covariance),
558559 std::move (theta),
559560 std::move (W),
560- std::move ( Eigen::MatrixXd (L) ),
561+ Eigen::MatrixXd (L),
561562 std::move (a),
562563 std::move (theta_grad),
563564 Eigen::PartialPivLU<Eigen::MatrixXd>{},
@@ -624,7 +625,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
624625 std::move (covariance),
625626 std::move (theta),
626627 std::move (W_r),
627- std::move ( Eigen::MatrixXd (L) ),
628+ Eigen::MatrixXd (L),
628629 std::move (a),
629630 std::move (theta_grad),
630631 Eigen::PartialPivLU<Eigen::MatrixXd>{},
@@ -796,7 +797,7 @@ template <typename Output, typename Input,
796797 require_t <is_all_arithmetic_scalar<Output>>* = nullptr ,
797798 require_t <is_all_arithmetic_scalar<Input>>* = nullptr >
798799inline void collect_adjoints (Output&& output, Input&& input) {
799- return iter_tuple_n (
800+ return iter_tuple_nested (
800801 [](auto && output_i, auto && input_i) {
801802 using output_i_t = std::decay_t <decltype (output_i)>;
802803 if constexpr (is_std_vector_v<output_i_t >) {
@@ -811,7 +812,7 @@ inline void collect_adjoints(Output&& output, Input&& input) {
811812 output_i += input_i;
812813 } else {
813814 static_assert (
814- 1 ,
815+ sizeof (Output*) == 0 ,
815816 " INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not "
816817 " able to deduce the actiopns needed for the given type." );
817818 }
@@ -837,7 +838,7 @@ template <bool ZeroInput = false, typename Output, typename Input,
837838 require_t <is_all_arithmetic_scalar<Output>>* = nullptr ,
838839 require_t <is_any_var_scalar<Input>>* = nullptr >
839840inline void copy_compute_s2 (Output&& output, Input&& input) {
840- return iter_tuple_n (
841+ return iter_tuple_nested (
841842 [](auto && output_i, auto && input_i) {
842843 using output_i_t = std::decay_t <decltype (output_i)>;
843844 if constexpr (is_std_vector_v<output_i_t >) {
@@ -861,14 +862,21 @@ inline void copy_compute_s2(Output&& output, Input&& input) {
861862 }
862863 } else {
863864 static_assert (
864- 1 ,
865+ sizeof (Output*) == 0 ,
865866 " INTERNAL ERROR:(laplace_marginal_lpdf) copy_compute_s2 was not "
866867 " able to deduce the actiopns needed for the given type." );
867868 }
868869 },
869870 std::forward<Output>(output), std::forward<Input>(input));
870871}
871872
873+ template <typename T>
874+ inline constexpr decltype (auto ) filter_var_scalar_types(T&& t) {
875+ return stan::math::filter_map<is_any_var_scalar>(
876+ [](auto && arg) -> decltype (auto ) {
877+ return std::forward<decltype (arg)>(arg);
878+ }, std::forward<T>(t));
879+ }
872880/* *
873881 * Creates an arena type from the input with initialized with zeros
874882 * @tparam Input Possibly a tuple, std::vector, Eigen type, or scalar
@@ -900,37 +908,6 @@ inline constexpr auto make_zeroed_arena(Input&& input) {
900908 }
901909}
902910
903- /* *
904- * Helper function for printing out adjoints
905- */
906- template <typename Output, require_t <is_any_var_scalar<Output>>* = nullptr >
907- inline void print_adjoint (Output&& output) {
908- if constexpr (is_tuple_v<Output>) {
909- std::cout << " tuple adj\n " ;
910- return stan::math::for_each (
911- [](auto && output_i) { return print_adjoint (output_i); }, output);
912- } else if constexpr (is_std_vector_v<Output>) {
913- if constexpr (is_var_v<value_type_t <Output>>) {
914- Eigen::Map<const Eigen::Matrix<var, -1 , -1 >> map_x (output.data (),
915- output.size ());
916- std::cout << " eigen adj: \n " << map_x.adj () << std::endl;
917- } else {
918- std::cout << " stdvec adjoint\n " ;
919- for (int i = 0 ; i < output.size (); ++i) {
920- print_adjoint (output[i]);
921- }
922- }
923- } else if constexpr (is_eigen_v<Output>) {
924- std::cout << " adj: \n " << output.adj () << std::endl;
925- } else if constexpr (is_stan_scalar_v<Output>) {
926- std::cout << " adj: " << output.adj () << std::endl;
927- } else {
928- static_assert (1 ,
929- " INTERNAL ERROR:(laplace_marginal_lpdf) print_adjoint was "
930- " not able to deduce the actiopns needed for the given type." );
931- }
932- }
933-
934911/* *
935912 * Used in reverse pass to collect adjoints to the output
936913 * @tparam Output A tuple or type where all scalar types are `var` types
@@ -942,7 +919,7 @@ inline void print_adjoint(Output&& output) {
942919template <typename Output, typename Input>
943920inline void collect_adjoints (Output&& output, const vari* ret, Input&& input) {
944921 if constexpr (is_tuple_v<Output>) {
945- static_assert (1 ,
922+ static_assert (!is_tuple_v<Output> ,
946923 " INTERNAL ERROR:(laplace_marginal_lpdf)"
947924 " Accumulate Adjoints called on a tuple, but tuples cannot be "
948925 " on the reverse mode stack!"
@@ -1067,11 +1044,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10671044 ll_fun, ll_args_copy, value_of (theta_0), covariance_function,
10681045 value_of (covar_args_refs), options, msgs);
10691046 // Return references to var types
1070- auto ll_args_filter = stan::math::filter_map<is_any_var_scalar>(
1071- [](auto && arg) -> decltype (auto ) {
1072- return std::forward<decltype (arg)>(arg);
1073- },
1074- ll_args_copy);
1047+ auto ll_args_filter = internal::filter_var_scalar_types (ll_args_copy);
10751048 stan::math::for_each (
10761049 [](auto && output_i, auto && ll_arg_i) {
10771050 if (is_any_var_scalar_v<decltype (ll_arg_i)>) {
@@ -1171,11 +1144,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11711144 K_var.adj ().array () += vi.adj () * K_adj_arena.array ();
11721145 });
11731146 grad (Z.vi_ );
1174- auto covar_args_filter = stan::math::filter_map<is_any_var_scalar>(
1175- [](auto && arg) -> decltype (auto ) {
1176- return std::forward<decltype (arg)>(arg);
1177- },
1178- covar_args_copy);
1147+ auto covar_args_filter = internal::filter_var_scalar_types (covar_args_copy);
11791148 internal::collect_adjoints (covar_args_adj, covar_args_filter);
11801149 }();
11811150 }
@@ -1194,14 +1163,12 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11941163 }
11951164 var ret (lmd);
11961165 if constexpr (is_any_var_scalar_v<CovarArgs>) {
1197- auto covar_args_filter = stan::math::filter_map<is_any_var_scalar>(
1198- [](auto && arg) -> decltype (auto ) { return arg; }, covar_args_refs);
1166+ auto covar_args_filter = internal::filter_var_scalar_types (covar_args_refs);
11991167 internal::reverse_pass_collect_adjoints (ret, covar_args_filter,
12001168 covar_args_adj);
12011169 }
12021170 if constexpr (ll_args_contain_var) {
1203- auto ll_args_filter = stan::math::filter_map<is_any_var_scalar>(
1204- [](auto && arg) -> decltype (auto ) { return arg; }, ll_args_refs);
1171+ auto ll_args_filter = internal::filter_var_scalar_types (ll_args_refs);
12051172 internal::reverse_pass_collect_adjoints (ret, ll_args_filter, partial_parm);
12061173 }
12071174 return ret;
0 commit comments