|
4 | 4 | #include <stan/math/prim/meta.hpp> |
5 | 5 | #include <stan/math/prim/err.hpp> |
6 | 6 | #include <stan/math/prim/fun/constants.hpp> |
7 | | -#include <stan/math/prim/fun/scalar_seq_view.hpp> |
8 | 7 | #include <stan/math/prim/fun/size.hpp> |
9 | 8 | #include <stan/math/prim/fun/size_zero.hpp> |
10 | | -#include <stan/math/prim/fun/value_of.hpp> |
| 9 | +#include <stan/math/prim/fun/dot_self.hpp> |
| 10 | +#include <stan/math/prim/fun/as_value_column_vector_or_scalar.hpp> |
11 | 11 | #include <stan/math/prim/functor/partials_propagator.hpp> |
12 | 12 |
|
13 | 13 | namespace stan { |
@@ -43,22 +43,16 @@ return_type_t<T_y> std_normal_lpdf(const T_y& y) { |
43 | 43 | return 0.0; |
44 | 44 | } |
45 | 45 |
|
46 | | - T_partials_return logp(0.0); |
| 46 | + const auto& y_val = as_value_column_vector_or_scalar(y_ref); |
| 47 | + T_partials_return logp = -dot_self(y_val) / 2.0; |
47 | 48 | auto ops_partials = make_partials_propagator(y_ref); |
48 | 49 |
|
49 | | - scalar_seq_view<T_y_ref> y_vec(y_ref); |
50 | | - size_t N = stan::math::size(y); |
51 | | - |
52 | | - for (size_t n = 0; n < N; n++) { |
53 | | - const T_partials_return y_val = y_vec.val(n); |
54 | | - logp += y_val * y_val; |
55 | | - if (!is_constant_all<T_y>::value) { |
56 | | - partials<0>(ops_partials)[n] -= y_val; |
57 | | - } |
| 50 | + if (!is_constant_all<T_y>::value) { |
| 51 | + partials<0>(ops_partials) = -y_val; |
58 | 52 | } |
59 | | - logp *= -0.5; |
| 53 | + |
60 | 54 | if (include_summand<propto>::value) { |
61 | | - logp += NEG_LOG_SQRT_TWO_PI * N; |
| 55 | + logp += NEG_LOG_SQRT_TWO_PI * math::size(y); |
62 | 56 | } |
63 | 57 |
|
64 | 58 | return ops_partials.build(logp); |
|
0 commit comments