Skip to content

Commit ecd9db7

Browse files
committed
merge to upstream
2 parents e0514cd + a01c01f commit ecd9db7

3 files changed

Lines changed: 55 additions & 34 deletions

File tree

stan/math/fwd/fun/pow.hpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,27 @@ namespace stan {
1818
namespace math {
1919
/*
2020
*
21-
* @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type.
22-
* @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type.
21+
* @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner
22+
* `fvar` or `arithmetic` type.
23+
* @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner
24+
* `fvar` or `arithmetic` type.
2325
* @param x1 Base variable.
2426
* @param x2 Exponent variable.
2527
* @return Base raised to the exponent.
2628
*/
27-
template <typename T1, typename T2,
28-
require_any_fvar_t<base_type_t<T1>, base_type_t<T2>>* = nullptr,
29-
require_all_stan_scalar_t<T1, T2>* = nullptr>
29+
template <typename T1, typename T2,
30+
require_any_fvar_t<base_type_t<T1>, base_type_t<T2>>* = nullptr,
31+
require_all_stan_scalar_t<T1, T2>* = nullptr>
3032
inline auto pow(const T1& x1, const T2& x2) {
3133
using std::log;
3234
using std::pow;
3335
if constexpr (is_complex<T1>::value || is_complex<T2>::value) {
3436
return internal::complex_pow(x1, x2);
3537
} else if constexpr (is_fvar<T1>::value && is_fvar<T2>::value) {
3638
auto pow_x1_x2(stan::math::pow(x1.val_, x2.val_));
37-
return T1(pow_x1_x2, (x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_)
38-
* pow_x1_x2);
39+
return T1(pow_x1_x2,
40+
(x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_)
41+
* pow_x1_x2);
3942
} else if constexpr (is_fvar<T2>::value) {
4043
auto u = stan::math::pow(x1, x2.val_);
4144
return T2(u, x2.d_ * stan::math::log(x1) * u);
@@ -59,10 +62,26 @@ inline auto pow(const T1& x1, const T2& x2) {
5962
if (x2 == 2.0) {
6063
return stan::math::square(x1);
6164
}
62-
return T1(stan::math::pow(x1.val_, x2), x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1));
65+
return T1(stan::math::pow(x1.val_, x2),
66+
x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1));
6367
}
6468
}
6569

70+
<<<<<<< HEAD
71+
=======
72+
// must uniquely match all pairs of:
73+
// { complex<fvar<V>>, complex<T>, fvar<V>, T }
74+
// with at least one fvar<V> and at least one complex, where T is arithmetic:
75+
// 1) complex<fvar<V>>, complex<fvar<V>>
76+
// 2) complex<fvar<V>>, complex<T>
77+
// 3) complex<fvar<V>>, fvar<V>
78+
// 4) complex<fvar<V>>, T
79+
// 5) complex<T>, complex<fvar<V>>
80+
// 6) complex<T>, fvar<V>
81+
// 7) fvar<V>, complex<fvar<V>>
82+
// 8) fvar<V>, complex<T>
83+
// 9) T, complex<fvar<V>>
84+
>>>>>>> origin/fix/pow-overload-resolution
6685

6786
/**
6887
* Returns the elementwise raising of the first argument to the power of the

stan/math/prim/fun/pow.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,26 @@ inline complex_return_t<U, V> complex_pow(const U& x, const V& y) {
4040
* @return the first argument raised to the power of the second
4141
* argument.
4242
*/
43-
template <typename T1, typename T2,
44-
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
43+
template <typename T1, typename T2, require_arithmetic_t<T1>* = nullptr,
44+
require_arithmetic_t<T2>* = nullptr>
4545
inline auto pow(const std::complex<T1>& a, const std::complex<T2>& b) {
4646
return std::pow(a, b);
4747
}
4848

49-
template <typename T1, typename T2,
50-
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
49+
template <typename T1, typename T2, require_arithmetic_t<T1>* = nullptr,
50+
require_arithmetic_t<T2>* = nullptr>
5151
inline auto pow(const T1& a, const std::complex<T2>& b) {
5252
return std::pow(a, b);
5353
}
5454

55-
template <typename T1, typename T2,
56-
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
55+
template <typename T1, typename T2, require_arithmetic_t<T1>* = nullptr,
56+
require_arithmetic_t<T2>* = nullptr>
5757
inline auto pow(const std::complex<T1>& a, const T2& b) {
5858
return std::pow(a, b);
5959
}
6060

61-
template <typename T1, typename T2,
62-
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
61+
template <typename T1, typename T2, require_arithmetic_t<T1>* = nullptr,
62+
require_arithmetic_t<T2>* = nullptr>
6363
inline auto pow(const T1& a, const T2& b) {
6464
return std::pow(a, b);
6565
}

stan/math/rev/fun/pow.hpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ namespace math {
6161
\end{cases}
6262
\f]
6363
*
64-
* @tparam Scal1 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type.
65-
* @tparam Scal2 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type.
64+
* @tparam Scal1 Either a `var`, `arithmetic`, or `complex` type with an inner
65+
`var` or `arithmetic` type.
66+
* @tparam Scal2 Either a `var`, `arithmetic`, or `complex` type with an inner
67+
`var` or `arithmetic` type.
6668
* @param base Base variable.
6769
* @param exponent Exponent variable.
6870
* @return Base raised to the exponent.
@@ -89,22 +91,23 @@ inline auto pow(const Scal1& base, const Scal2& exponent) {
8991
return inv_sqrt(base);
9092
}
9193
}
92-
return make_callback_var(
93-
std::pow(value_of(base), value_of(exponent)),
94-
[base, exponent](auto&& vi) mutable {
95-
if (value_of(base) == 0.0) {
96-
return; // partials zero, avoids 0 & log(0)
97-
}
98-
const double vi_mul = vi.adj() * vi.val();
94+
return make_callback_var(std::pow(value_of(base), value_of(exponent)),
95+
[base, exponent](auto&& vi) mutable {
96+
if (value_of(base) == 0.0) {
97+
return; // partials zero, avoids 0 & log(0)
98+
}
99+
const double vi_mul = vi.adj() * vi.val();
99100

100-
if (!is_constant<Scal1>::value) {
101-
forward_as<var>(base).adj()
102-
+= vi_mul * value_of(exponent) / value_of(base);
103-
}
104-
if (!is_constant<Scal2>::value) {
105-
forward_as<var>(exponent).adj() += vi_mul * std::log(value_of(base));
106-
}
107-
});
101+
if (!is_constant<Scal1>::value) {
102+
forward_as<var>(base).adj()
103+
+= vi_mul * value_of(exponent)
104+
/ value_of(base);
105+
}
106+
if (!is_constant<Scal2>::value) {
107+
forward_as<var>(exponent).adj()
108+
+= vi_mul * std::log(value_of(base));
109+
}
110+
});
108111
}
109112
}
110113

@@ -268,7 +271,6 @@ inline auto pow(Scal1 base, const Mat1& exponent) {
268271
return ret_type(ret);
269272
}
270273

271-
272274
/**
273275
* Returns the elementwise raising of the first argument to the power of the
274276
* second argument.

0 commit comments

Comments
 (0)