33
44#include < stan/math/fwd/meta.hpp>
55#include < stan/math/fwd/core.hpp>
6- #include < stan/math/fwd/fun/sqrt.hpp>
76#include < stan/math/fwd/fun/inv.hpp>
87#include < stan/math/fwd/fun/inv_sqrt.hpp>
98#include < stan/math/fwd/fun/inv_square.hpp>
9+ #include < stan/math/fwd/fun/sqrt.hpp>
10+ #include < stan/math/fwd/fun/square.hpp>
1011#include < stan/math/prim/fun/pow.hpp>
1112#include < cmath>
1213#include < complex>
1314#include < type_traits>
1415
1516namespace stan {
1617namespace math {
17-
18- template < typename T>
19- inline fvar<T> pow ( const fvar<T>& x1, const fvar<T>& x2) {
20- using std::log;
21- using std::pow;
22- T pow_x1_x2 ( pow (x1. val_ , x2. val_ ));
23- return fvar<T>(pow_x1_x2, (x2. d_ * log (x1. val_ ) + x2. val_ * x1. d_ / x1. val_ )
24- * pow_x1_x2);
25- }
26-
27- template < typename T, typename U, typename = require_arithmetic_t <U> >
28- inline fvar<T> pow (const U & x1, const fvar<T> & x2) {
18+ /*
19+ *
20+ * @tparam T1 Either an ` fvar`, `arithmetic`, or `complex` type with an inner ` fvar` or `arithmetic` type.
21+ * @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type.
22+ * @param x1 Base variable.
23+ * @param x2 Exponent variable.
24+ * @ return Base raised to the exponent.
25+ */
26+ template < typename T1, typename T2,
27+ require_any_fvar_t < base_type_t <T1>, base_type_t <T2>>* = nullptr ,
28+ require_all_stan_scalar_t <T1, T2>* = nullptr >
29+ inline auto pow (const T1 & x1, const T2 & x2) {
2930 using std::log;
3031 using std::pow;
31- T u = pow (x1, x2.val_ );
32- return fvar<T>(u, x2.d_ * log (x1) * u);
33- }
34-
35- template <typename T, typename U, typename = require_arithmetic_t <U>>
36- inline fvar<T> pow (const fvar<T>& x1, const U& x2) {
37- using std::pow;
38- using std::sqrt;
39- if (x2 == -2 ) {
40- return inv_square (x1);
41- }
42- if (x2 == -1 ) {
43- return inv (x1);
44- }
45- if (x2 == -0.5 ) {
46- return inv_sqrt (x1);
47- }
48- if (x2 == 0.5 ) {
49- return sqrt (x1);
50- }
51- if (x2 == 1.0 ) {
52- return x1;
32+ if constexpr (is_complex<T1>::value || is_complex<T2>::value) {
33+ return internal::complex_pow (x1, x2);
34+ } else if constexpr (is_fvar<T1>::value && is_fvar<T2>::value) {
35+ auto pow_x1_x2 (stan::math::pow (x1.val_ , x2.val_ ));
36+ return T1 (pow_x1_x2, (x2.d_ * stan::math::log (x1.val_ ) + x2.val_ * x1.d_ / x1.val_ )
37+ * pow_x1_x2);
38+ } else if constexpr (is_fvar<T2>::value) {
39+ auto u = stan::math::pow (x1, x2.val_ );
40+ return T2 (u, x2.d_ * stan::math::log (x1) * u);
41+ } else {
42+ using std::sqrt;
43+ if (x2 == -2 ) {
44+ return stan::math::inv_square (x1);
45+ }
46+ if (x2 == -1 ) {
47+ return stan::math::inv (x1);
48+ }
49+ if (x2 == -0.5 ) {
50+ return stan::math::inv_sqrt (x1);
51+ }
52+ if (x2 == 0.5 ) {
53+ return stan::math::sqrt (x1);
54+ }
55+ if (x2 == 1.0 ) {
56+ return x1;
57+ }
58+ if (x2 == 2.0 ) {
59+ return stan::math::square (x1);
60+ }
61+ return T1 (stan::math::pow (x1.val_ , x2), x1.d_ * x2 * stan::math::pow (x1.val_ , x2 - 1 ));
5362 }
54- if (x2 == 2.0 ) {
55- return square (x1);
56- }
57- return fvar<T>(pow (x1.val_ , x2), x1.d_ * x2 * pow (x1.val_ , x2 - 1 ));
5863}
5964
65+
6066// must uniquely match all pairs of:
6167// { complex<fvar<V>>, complex<T>, fvar<V>, T }
6268// with at least one fvar<V> and at least one complex, where T is arithmetic:
@@ -70,148 +76,26 @@ inline fvar<T> pow(const fvar<T>& x1, const U& x2) {
7076// 8) fvar<V>, complex<T>
7177// 9) T, complex<fvar<V>>
7278
73- /* *
74- * Return the first argument raised to the power of the second argument.
75- *
76- * @param x first argument
77- * @param y second argument
78- * @return first argument to the power of the second argument
79- */
80- template <typename V>
81- inline std::complex <fvar<V>> pow (const std::complex <fvar<V>>& x,
82- const std::complex <fvar<V>>& y) {
83- return internal::complex_pow (x, y);
84- }
8579
86- /* *
87- * Return the first argument raised to the power of the second argument.
88- *
89- * @tparam V autodiff value type
90- * @tparam T arithmetic type
91- * @param x first argument
92- * @param y second argument
93- * @return first argument to the power of the second argument
94- */
95- template <typename V, typename T, typename = require_arithmetic_t <T>>
96- inline std::complex <fvar<V>> pow (const std::complex <fvar<V>>& x,
97- const std::complex <T>& y) {
98- return internal::complex_pow (x, y);
99- }
10080
101- /* *
102- * Return the first argument raised to the power of the second argument.
103- *
104- * @tparam V autodiff value type
105- * @param x first argument
106- * @param y second argument
107- * @return first argument to the power of the second argument
108- */
109- template <typename V>
110- inline std::complex <fvar<V>> pow (const std::complex <fvar<V>>& x,
111- const fvar<V>& y) {
112- return internal::complex_pow (x, y);
113- }
11481
11582/* *
116- * Return the first argument raised to the power of the second argument.
117- *
118- * @tparam V autodiff value type
119- * @tparam T arithmetic type
120- * @param x first argument
121- * @param y second argument
122- * @return first argument to the power of the second argument
123- */
124- template <typename V, typename T, typename = require_arithmetic_t <T>>
125- inline std::complex <fvar<V>> pow (const std::complex <fvar<V>>& x, const T& y) {
126- return internal::complex_pow (x, y);
127- }
128-
129- /* *
130- * Return the first argument raised to the power of the second argument.
131- *
132- * @tparam V autodiff value type
133- * @tparam T arithmetic type
134- * @param x first argument
135- * @param y second argument
136- * @return first argument to the power of the second argument
137- */
138- template <typename V, typename T, typename = require_arithmetic_t <T>>
139- inline std::complex <fvar<V>> pow (const std::complex <T>& x,
140- const std::complex <fvar<V>>& y) {
141- return internal::complex_pow (x, y);
142- }
143-
144- /* *
145- * Return the first argument raised to the power of the second argument.
146- *
147- * @tparam V autodiff value type
148- * @tparam T arithmetic type
149- * @param x first argument
150- * @param y second argument
151- * @return first argument to the power of the second argument
152- */
153- template <typename V, typename T, typename = require_arithmetic_t <T>>
154- inline std::complex <fvar<V>> pow (const std::complex <T>& x, const fvar<V>& y) {
155- return internal::complex_pow (x, y);
156- }
157-
158- /* *
159- * Return the first argument raised to the power of the second argument.
160- *
161- * @tparam V autodiff value type
162- * @param x first argument
163- * @param y second argument
164- * @return first argument to the power of the second argument
165- */
166- template <typename V>
167- inline std::complex <fvar<V>> pow (const fvar<V>& x,
168- const std::complex <fvar<V>>& y) {
169- return internal::complex_pow (x, y);
170- }
171-
172- /* *
173- * Return the first argument raised to the power of the second argument.
174- *
175- * @tparam V autodiff value type
176- * @tparam T arithmetic type
177- * @param x first argument
178- * @param y second argument
179- * @return first argument to the power of the second argument
180- */
181- template <typename V, typename T, typename = require_arithmetic_t <T>>
182- inline std::complex <fvar<V>> pow (const fvar<V>& x, const std::complex <T>& y) {
183- return internal::complex_pow (x, y);
184- }
185-
186- /* *
187- * Return the first argument raised to the power of the second argument.
188- *
189- * @tparam V autodiff value type
190- * @tparam T real type (`fvar<V>` or arithmetic)
191- * @param x first argument
192- * @param y second argument
193- * @return first argument to the power of the second argument
194- */
195- template <typename T, typename V, typename = require_arithmetic_t <T>>
196- inline std::complex <fvar<V>> pow (const T& x, const std::complex <fvar<V>>& y) {
197- return internal::complex_pow (x, y);
198- }
199-
200- /* *
201- * Return the first argument raised to the power of the second argument.
202- *
203- * Note: this overload is required because gcc still provides the
204- * C++99 template function `pow(complex<T>, int)`, which introduces
205- * an ambiguity.
83+ * Returns the elementwise raising of the first argument to the power of the
84+ * second argument.
20685 *
207- * @tparam T autodiff value type
208- * @param x first argument
209- * @param y second argument
210- * @return first argument to the power of the second argument
86+ * @tparam T1 type of first argument
87+ * @tparam T2 type of second argument
88+ * @param a first argument
89+ * @param b second argument
90+ * @return the elementwise raising of the first argument to the power of the
91+ * second argument.
21192 */
212- template <typename T>
213- inline std::complex <fvar<T>> pow (const std::complex <fvar<T>>& x, int y) {
214- return internal::complex_pow (x, y);
93+ template <typename T1, typename T2, require_any_container_t <T1, T2>* = nullptr ,
94+ require_all_not_matrix_st<is_var, T1, T2>* = nullptr ,
95+ require_any_fvar_t <base_type_t <T1>, base_type_t <T2>>* = nullptr >
96+ inline auto pow (const T1& a, const T2& b) {
97+ return apply_scalar_binary (
98+ a, b, [](const auto & c, const auto & d) { return stan::math::pow (c, d); });
21599}
216100
217101} // namespace math
0 commit comments