Skip to content

Commit 4f97d77

Browse files
committed
update with brians review
1 parent 1913a45 commit 4f97d77

12 files changed

Lines changed: 25 additions & 16 deletions

File tree

stan/math/fwd/fun/log_softmax.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/fwd/meta.hpp>
77
#include <stan/math/fwd/fun/softmax.hpp>
88
#include <stan/math/prim/fun/log_softmax.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910

1011
namespace stan {
1112
namespace math {
@@ -20,23 +21,23 @@ namespace math {
2021
*/
2122
template <typename T, require_vector_st<is_fvar, T>* = nullptr>
2223
inline auto log_softmax(T&& x) {
23-
return apply_vector_unary<T>::apply(std::forward<T>(x), [&](auto&& alpha) {
24+
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) {
2425
using T_alpha = decltype(alpha);
2526
using T_fvar = value_type_t<T_alpha>;
2627
using T_fvar_inner = typename T_fvar::Scalar;
2728

28-
const Eigen::Ref<const plain_type_t<T_alpha>>& alpha_ref = alpha;
29+
auto&& alpha_ref = to_ref(std::forward<decltype(alpha)>(alpha));
2930
Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha_ref.val();
3031
Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t = softmax(alpha_t);
3132

32-
Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha.size());
33+
Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha_ref.size());
3334
log_softmax_alpha.val() = log_softmax(alpha_t);
3435
log_softmax_alpha.d().setZero();
3536

36-
for (int m = 0; m < alpha.size(); ++m) {
37+
for (int m = 0; m < alpha_ref.size(); ++m) {
3738
T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
3839
= -alpha_ref.coeff(m).d_ * softmax_alpha_t(m);
39-
for (int k = 0; k < alpha.size(); ++k) {
40+
for (int k = 0; k < alpha_ref.size(); ++k) {
4041
if (m == k) {
4142
log_softmax_alpha(k).d_
4243
+= alpha_ref.coeff(m).d_

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
5252
template <typename T, require_container_st<is_fvar, T>* = nullptr>
5353
inline auto log_sum_exp(T&& x) {
5454
return apply_vector_unary<ref_type_t<T>>::reduce(
55-
to_ref(std::forward<T>(x)), [&](auto&& v) {
55+
to_ref(std::forward<T>(x)), [](auto&& v) {
5656
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
5757
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
5858
mat_type vals = v.val();

stan/math/fwd/fun/norm1.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace math {
2222
template <typename Container, require_eigen_vt<is_fvar, Container>* = nullptr>
2323
inline auto norm1(Container&& x) {
2424
return apply_vector_unary<ref_type_t<Container>>::reduce(
25-
to_ref(std::forward<Container>(x)), [&](auto&& v) {
25+
to_ref(std::forward<Container>(x)), [](auto&& v) {
2626
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
2727
return fvar<T_fvar_inner>(norm1(v.val()),
2828
v.d().cwiseProduct(sign(v.val())).sum());

stan/math/fwd/fun/norm2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace math {
2121
template <typename Container, require_eigen_vt<is_fvar, Container>* = nullptr>
2222
inline auto norm2(Container&& x) {
2323
return apply_vector_unary<ref_type_t<Container>>::reduce(
24-
to_ref(std::forward<Container>(x)), [&](auto&& v) {
24+
to_ref(std::forward<Container>(x)), [](auto&& v) {
2525
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
2626
T_fvar_inner res = norm2(v.val());
2727
return fvar<T_fvar_inner>(res,

stan/math/prim/constraint/prob_constrain.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ inline auto prob_constrain(T&& x) {
4949
*/
5050
template <typename T, typename Lp>
5151
inline auto prob_constrain(T&& x, Lp& lp) {
52-
std::decay_t<T> log_inv_logit_x = log_inv_logit(x);
52+
plain_type_t<T> log_inv_logit_x = log_inv_logit(x);
5353
lp += log_inv_logit_x + log1m_inv_logit(std::forward<T>(x));
5454
return exp(std::move(log_inv_logit_x));
5555
}

stan/math/prim/fun/abs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ template <typename Container,
8080
require_container_bt<std::is_arithmetic, Container>* = nullptr>
8181
inline auto abs(Container&& x) {
8282
return apply_vector_unary<Container>::apply(
83-
std::forward<Container>(x), [&](auto&& v) { return v.array().abs(); });
83+
std::forward<Container>(x), [](auto&& v) { return v.array().abs(); });
8484
}
8585

8686
namespace internal {

stan/math/prim/fun/as_array_or_scalar.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@ inline auto as_array_or_scalar(T&& v) {
5959

6060
/**
6161
* Converts a std::vector type to an array.
62-
*
62+
* @note The math library's reverse mode assumes that `Eigen::Map`
63+
* types are allocated and owned elsewhere so we cannot just return
64+
* back a map here else the reverse mode library
65+
* may try to access into a dangling pointer. Instead we wrap
66+
* the `Eigen::Map` in a `Holder` to trick the reverse mode library
67+
* into not thinking this is a map. The `.array().matrix()` inside the
68+
* holder is so that the holder thinks it is returning an expression.
6369
* @tparam T Type of scalar element.
6470
* @param v Specified vector.
6571
* @return Matrix converted to an array.

stan/math/prim/fun/cbrt.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ template <
4040
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
4141
require_container_t<T>* = nullptr>
4242
inline auto cbrt(T&& x) {
43-
return apply_scalar_unary<cbrt_fun, T>::apply(x);
43+
return apply_scalar_unary<cbrt_fun, T>::apply(std::forward<T>(x));
4444
}
4545

4646
} // namespace math

stan/math/prim/fun/cos.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ template <typename Container,
7777
require_container_bt<std::is_arithmetic, Container>* = nullptr>
7878
inline auto cos(Container&& x) {
7979
return apply_vector_unary<Container>::apply(
80-
std::forward<Container>(x), [&](auto&& v) { return v.array().cos(); });
80+
std::forward<Container>(x), [](auto&& v) { return v.array().cos(); });
8181
}
8282

8383
namespace internal {

stan/math/prim/fun/inv_square.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ template <typename Container,
2727
Container>* = nullptr,
2828
require_container_t<Container>* = nullptr>
2929
inline auto inv_square(Container&& x) {
30-
return inv(square(std::forward<Container>(x)));
30+
return make_holder([](auto&& v) {
31+
return inv(square(std::forward<decltype(v)>(v)));
32+
}, std::forward<Container>(x));
3133
}
3234

3335
/**

0 commit comments

Comments
 (0)