Skip to content

Commit 6ee704b

Browse files
committed
update generate_laplace_options to be in
1 parent 6446f9c commit 6ee704b

4 files changed

Lines changed: 75 additions & 49 deletions

File tree

stan/math/mix/functor/laplace_marginal_density_estimator.hpp

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_ESTIMATOR_HPP
22
#define STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_ESTIMATOR_HPP
33
#include <stan/math/prim/fun/Eigen.hpp>
4+
#include <stan/math/prim/fun/generate_laplace_options.hpp>
45
#include <stan/math/mix/functor/laplace_likelihood.hpp>
56
#include <stan/math/mix/functor/wolfe_line_search.hpp>
67
#include <stan/math/rev/meta.hpp>
@@ -31,7 +32,7 @@ namespace math {
3132
*/
3233
struct laplace_options_base {
3334
/* Size of the blocks in block diagonal hessian*/
34-
int hessian_block_size{1}; // 0
35+
int hessian_block_size{internal::laplace_default_hessian_block_size}; // 0
3536
/**
3637
* Which linear solver to use inside the Newton step.
3738
*
@@ -47,19 +48,20 @@ struct laplace_options_base {
4748
* `Sigma = K_root * K_root^T` and form `B = I + K_root^T * W * K_root`.
4849
* 3. General LU: form `B = I + Sigma * W` and factorize with LU.
4950
*/
50-
int solver{1}; // 1
51+
int solver{internal::laplace_default_solver}; // 1
5152
/**
5253
* Iterations end when the absolute change in the optimization objective
5354
* is less than this tolerance.
5455
*
5556
* Note: the objective used for convergence is the one optimized by the
5657
* Newton/Wolfe loop (not the final Laplace-corrected log marginal density).
5758
*/
58-
double tolerance{1.49012e-08}; // 2
59+
double tolerance{internal::laplace_default_tolerance}; // 2
5960
/* Maximum number of steps*/
60-
int max_num_steps{500}; // 3
61-
int allow_fallthrough{true}; // 4
62-
laplace_line_search_options line_search; // 5
61+
int max_num_steps{internal::laplace_default_max_num_steps}; // 3
62+
int allow_fallthrough{internal::laplace_default_allow_fallthrough}; // 4
63+
laplace_line_search_options line_search{
64+
internal::laplace_default_max_steps_line_search}; // 5
6365
laplace_options_base() = default;
6466
laplace_options_base(int hessian_block_size_, int solver_, double tolerance_,
6567
int max_num_steps_, bool allow_fallthrough_,
@@ -102,34 +104,6 @@ struct laplace_options<true> : public laplace_options_base {
102104
using laplace_options_default = laplace_options<false>;
103105
using laplace_options_user_supplied = laplace_options<true>;
104106

105-
/**
106-
* User function for generating laplace options tuple
107-
* @param theta_0_size Size of user supplied initial theta
108-
* @return tuple representing laplace options exposed to user.
109-
*/
110-
inline auto generate_laplace_options(int theta_0_size) {
111-
auto ops = laplace_options_default{};
112-
return std::make_tuple(Eigen::VectorXd::Zero(theta_0_size).eval(),
113-
ops.tolerance, ops.max_num_steps, ops.solver,
114-
ops.line_search.max_iterations,
115-
static_cast<int>(ops.allow_fallthrough));
116-
}
117-
118-
/**
119-
* User function for generating laplace options tuple
120-
* @tparam ThetaVec An Eigen vector type for user supplied initial theta
121-
* @param theta_0 User supplied initial theta
122-
* @return tuple representing laplace options exposed to user.
123-
*/
124-
template <typename ThetaVec, require_eigen_t<ThetaVec>* = nullptr>
125-
inline auto generate_laplace_options(ThetaVec&& theta_0) {
126-
auto ops = laplace_options_default{};
127-
return std::make_tuple(std::forward<ThetaVec>(theta_0), ops.tolerance,
128-
ops.max_num_steps, ops.solver,
129-
ops.line_search.max_iterations,
130-
static_cast<int>(ops.allow_fallthrough));
131-
}
132-
133107
namespace internal {
134108

135109
template <typename Options>

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
#include <stan/math/prim/fun/fmod.hpp>
103103
#include <stan/math/prim/fun/gamma_p.hpp>
104104
#include <stan/math/prim/fun/gamma_q.hpp>
105+
#include <stan/math/prim/fun/generate_laplace_options.hpp>
105106
#include <stan/math/prim/fun/generalized_inverse.hpp>
106107
#include <stan/math/prim/fun/get.hpp>
107108
#include <stan/math/prim/fun/get_base1.hpp>
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef STAN_MATH_PRIM_FUN_GENERATE_LAPLACE_OPTIONS_HPP
2+
#define STAN_MATH_PRIM_FUN_GENERATE_LAPLACE_OPTIONS_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <tuple>
7+
#include <utility>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
namespace internal {
13+
inline constexpr int laplace_default_hessian_block_size = 1;
14+
inline constexpr int laplace_default_solver = 1;
15+
inline constexpr double laplace_default_tolerance = 1.49012e-08;
16+
inline constexpr int laplace_default_max_num_steps = 500;
17+
inline constexpr int laplace_default_allow_fallthrough = 1;
18+
inline constexpr int laplace_default_max_steps_line_search = 1000;
19+
} // namespace internal
20+
21+
/**
22+
* User function for generating laplace options tuple
23+
* @param theta_0_size Size of user supplied initial theta
24+
* @return tuple representing laplace options exposed to user.
25+
*/
26+
inline auto generate_laplace_options(int theta_0_size) {
27+
return std::make_tuple(
28+
Eigen::VectorXd::Zero(theta_0_size).eval(),
29+
internal::laplace_default_tolerance,
30+
internal::laplace_default_max_num_steps,
31+
internal::laplace_default_solver,
32+
internal::laplace_default_max_steps_line_search,
33+
internal::laplace_default_allow_fallthrough);
34+
}
35+
36+
/**
37+
* User function for generating laplace options tuple
38+
* @tparam ThetaVec An Eigen vector type for user supplied initial theta
39+
* @param theta_0 User supplied initial theta
40+
* @return tuple representing laplace options exposed to user.
41+
*/
42+
template <typename ThetaVec, require_eigen_t<ThetaVec>* = nullptr>
43+
inline auto generate_laplace_options(ThetaVec&& theta_0) {
44+
return std::make_tuple(
45+
std::forward<ThetaVec>(theta_0), internal::laplace_default_tolerance,
46+
internal::laplace_default_max_num_steps,
47+
internal::laplace_default_solver,
48+
internal::laplace_default_max_steps_line_search,
49+
internal::laplace_default_allow_fallthrough);
50+
}
51+
52+
} // namespace math
53+
} // namespace stan
54+
55+
#endif
Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
#include <stan/math/mix/functor/laplace_marginal_density_estimator.hpp>
1+
#include <stan/math/prim/fun/generate_laplace_options.hpp>
22
#include <gtest/gtest.h>
33

44
namespace stan::math::test {
55

66
TEST(LaplaceMarginalDensityEstimator, GenerateLaplaceOptionsFromSizeDefaults) {
77
constexpr int theta_0_size = 4;
8-
const auto defaults = laplace_options_default{};
9-
108
const auto options = stan::math::generate_laplace_options(theta_0_size);
119
const auto& theta_0 = std::get<0>(options);
1210

1311
EXPECT_EQ(theta_0_size, theta_0.size());
1412
EXPECT_TRUE(theta_0.isApprox(Eigen::VectorXd::Zero(theta_0_size)));
15-
EXPECT_DOUBLE_EQ(defaults.tolerance, std::get<1>(options));
16-
EXPECT_EQ(defaults.max_num_steps, std::get<2>(options));
17-
EXPECT_EQ(defaults.solver, std::get<3>(options));
18-
EXPECT_EQ(defaults.line_search.max_iterations, std::get<4>(options));
19-
EXPECT_EQ(static_cast<int>(defaults.allow_fallthrough), std::get<5>(options));
13+
EXPECT_DOUBLE_EQ(1.49012e-08, std::get<1>(options));
14+
EXPECT_EQ(500, std::get<2>(options));
15+
EXPECT_EQ(1, std::get<3>(options));
16+
EXPECT_EQ(1000, std::get<4>(options));
17+
EXPECT_EQ(1, std::get<5>(options));
2018
}
2119

2220
TEST(LaplaceMarginalDensityEstimator, GenerateLaplaceOptionsFromSizeZero) {
@@ -29,16 +27,14 @@ TEST(LaplaceMarginalDensityEstimator, GenerateLaplaceOptionsFromSizeZero) {
2927
TEST(LaplaceMarginalDensityEstimator, GenerateLaplaceOptionsFromThetaDefaults) {
3028
Eigen::VectorXd theta_0(3);
3129
theta_0 << -1.2, 0.5, 2.3;
32-
const auto defaults = laplace_options_default{};
33-
3430
const auto options = stan::math::generate_laplace_options(theta_0);
3531

3632
EXPECT_TRUE(theta_0.isApprox(std::get<0>(options)));
37-
EXPECT_DOUBLE_EQ(defaults.tolerance, std::get<1>(options));
38-
EXPECT_EQ(defaults.max_num_steps, std::get<2>(options));
39-
EXPECT_EQ(defaults.solver, std::get<3>(options));
40-
EXPECT_EQ(defaults.line_search.max_iterations, std::get<4>(options));
41-
EXPECT_EQ(static_cast<int>(defaults.allow_fallthrough), std::get<5>(options));
33+
EXPECT_DOUBLE_EQ(1.49012e-08, std::get<1>(options));
34+
EXPECT_EQ(500, std::get<2>(options));
35+
EXPECT_EQ(1, std::get<3>(options));
36+
EXPECT_EQ(1000, std::get<4>(options));
37+
EXPECT_EQ(1, std::get<5>(options));
4238
}
4339

4440
} // namespace stan::math::test

0 commit comments

Comments
 (0)