Skip to content

Commit ac977e4

Browse files
authored
Merge pull request #3140 from stan-dev/std_normal-vec
Simplify vectorisation of std_normal_lpdf
2 parents 42d94c4 + a929e7e commit ac977e4

3 files changed

Lines changed: 21 additions & 14 deletions

File tree

stan/math/prim/fun/dot_self.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
namespace stan {
1111
namespace math {
1212

13+
template <typename T, require_stan_scalar_t<T>* = nullptr>
14+
inline T dot_self(const T& x) {
15+
return x * x;
16+
}
17+
1318
inline double dot_self(const std::vector<double>& x) {
1419
double sum = 0.0;
1520
for (double i : x) {

stan/math/prim/prob/std_normal_lpdf.hpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/scalar_seq_view.hpp>
87
#include <stan/math/prim/fun/size.hpp>
98
#include <stan/math/prim/fun/size_zero.hpp>
10-
#include <stan/math/prim/fun/value_of.hpp>
9+
#include <stan/math/prim/fun/dot_self.hpp>
10+
#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp>
1111
#include <stan/math/prim/functor/partials_propagator.hpp>
1212

1313
namespace stan {
@@ -43,22 +43,16 @@ return_type_t<T_y> std_normal_lpdf(const T_y& y) {
4343
return 0.0;
4444
}
4545

46-
T_partials_return logp(0.0);
46+
const auto& y_val = as_value_column_vector_or_scalar(y_ref);
47+
T_partials_return logp = -dot_self(y_val) / 2.0;
4748
auto ops_partials = make_partials_propagator(y_ref);
4849

49-
scalar_seq_view<T_y_ref> y_vec(y_ref);
50-
size_t N = stan::math::size(y);
51-
52-
for (size_t n = 0; n < N; n++) {
53-
const T_partials_return y_val = y_vec.val(n);
54-
logp += y_val * y_val;
55-
if (!is_constant_all<T_y>::value) {
56-
partials<0>(ops_partials)[n] -= y_val;
57-
}
50+
if (!is_constant_all<T_y>::value) {
51+
partials<0>(ops_partials) = -y_val;
5852
}
59-
logp *= -0.5;
53+
6054
if (include_summand<propto>::value) {
61-
logp += NEG_LOG_SQRT_TWO_PI * N;
55+
logp += NEG_LOG_SQRT_TWO_PI * math::size(y);
6256
}
6357

6458
return ops_partials.build(logp);

test/unit/math/mix/prob/std_normal_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,12 @@ TEST_F(AgradRev, mathMixScalFun_std_normal) {
77
stan::test::expect_ad(f, -0.3);
88
stan::test::expect_ad(f, 0.0);
99
stan::test::expect_ad(f, 1.7);
10+
11+
Eigen::VectorXd x(3);
12+
x << -0.3, 0.0, 1.7;
13+
std::vector<double> x2{0.0, 1.7};
14+
15+
stan::test::expect_ad(f, x);
16+
stan::test::expect_ad(f, x.transpose().eval());
17+
stan::test::expect_ad(f, x2);
1018
}

0 commit comments

Comments
 (0)