Skip to content

Commit 836a5ef

Browse files
committed
Removes many of the signatures for pow.
Fix expr tests so they can work with python > 3
1 parent 9239589 commit 836a5ef

7 files changed

Lines changed: 130 additions & 363 deletions

File tree

stan/math/fwd/fun/pow.hpp

Lines changed: 59 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -3,60 +3,66 @@
33

44
#include <stan/math/fwd/meta.hpp>
55
#include <stan/math/fwd/core.hpp>
6-
#include <stan/math/fwd/fun/sqrt.hpp>
76
#include <stan/math/fwd/fun/inv.hpp>
87
#include <stan/math/fwd/fun/inv_sqrt.hpp>
98
#include <stan/math/fwd/fun/inv_square.hpp>
9+
#include <stan/math/fwd/fun/sqrt.hpp>
10+
#include <stan/math/fwd/fun/square.hpp>
1011
#include <stan/math/prim/fun/pow.hpp>
1112
#include <cmath>
1213
#include <complex>
1314
#include <type_traits>
1415

1516
namespace stan {
1617
namespace math {
17-
18-
template <typename T>
19-
inline fvar<T> pow(const fvar<T>& x1, const fvar<T>& x2) {
20-
using std::log;
21-
using std::pow;
22-
T pow_x1_x2(pow(x1.val_, x2.val_));
23-
return fvar<T>(pow_x1_x2, (x2.d_ * log(x1.val_) + x2.val_ * x1.d_ / x1.val_)
24-
* pow_x1_x2);
25-
}
26-
27-
template <typename T, typename U, typename = require_arithmetic_t<U>>
28-
inline fvar<T> pow(const U& x1, const fvar<T>& x2) {
18+
/*
19+
*
20+
* @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type.
21+
* @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type.
22+
* @param x1 Base variable.
23+
* @param x2 Exponent variable.
24+
* @return Base raised to the exponent.
25+
*/
26+
template <typename T1, typename T2,
27+
require_any_fvar_t<base_type_t<T1>, base_type_t<T2>>* = nullptr,
28+
require_all_stan_scalar_t<T1, T2>* = nullptr>
29+
inline auto pow(const T1& x1, const T2& x2) {
2930
using std::log;
3031
using std::pow;
31-
T u = pow(x1, x2.val_);
32-
return fvar<T>(u, x2.d_ * log(x1) * u);
33-
}
34-
35-
template <typename T, typename U, typename = require_arithmetic_t<U>>
36-
inline fvar<T> pow(const fvar<T>& x1, const U& x2) {
37-
using std::pow;
38-
using std::sqrt;
39-
if (x2 == -2) {
40-
return inv_square(x1);
41-
}
42-
if (x2 == -1) {
43-
return inv(x1);
44-
}
45-
if (x2 == -0.5) {
46-
return inv_sqrt(x1);
47-
}
48-
if (x2 == 0.5) {
49-
return sqrt(x1);
50-
}
51-
if (x2 == 1.0) {
52-
return x1;
32+
if constexpr (is_complex<T1>::value || is_complex<T2>::value) {
33+
return internal::complex_pow(x1, x2);
34+
} else if constexpr (is_fvar<T1>::value && is_fvar<T2>::value) {
35+
auto pow_x1_x2(stan::math::pow(x1.val_, x2.val_));
36+
return T1(pow_x1_x2, (x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_)
37+
* pow_x1_x2);
38+
} else if constexpr (is_fvar<T2>::value) {
39+
auto u = stan::math::pow(x1, x2.val_);
40+
return T2(u, x2.d_ * stan::math::log(x1) * u);
41+
} else {
42+
using std::sqrt;
43+
if (x2 == -2) {
44+
return stan::math::inv_square(x1);
45+
}
46+
if (x2 == -1) {
47+
return stan::math::inv(x1);
48+
}
49+
if (x2 == -0.5) {
50+
return stan::math::inv_sqrt(x1);
51+
}
52+
if (x2 == 0.5) {
53+
return stan::math::sqrt(x1);
54+
}
55+
if (x2 == 1.0) {
56+
return x1;
57+
}
58+
if (x2 == 2.0) {
59+
return stan::math::square(x1);
60+
}
61+
return T1(stan::math::pow(x1.val_, x2), x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1));
5362
}
54-
if (x2 == 2.0) {
55-
return square(x1);
56-
}
57-
return fvar<T>(pow(x1.val_, x2), x1.d_ * x2 * pow(x1.val_, x2 - 1));
5863
}
5964

65+
6066
// must uniquely match all pairs of:
6167
// { complex<fvar<V>>, complex<T>, fvar<V>, T }
6268
// with at least one fvar<V> and at least one complex, where T is arithmetic:
@@ -70,148 +76,26 @@ inline fvar<T> pow(const fvar<T>& x1, const U& x2) {
7076
// 8) fvar<V>, complex<T>
7177
// 9) T, complex<fvar<V>>
7278

73-
/**
74-
* Return the first argument raised to the power of the second argument.
75-
*
76-
* @param x first argument
77-
* @param y second argument
78-
* @return first argument to the power of the second argument
79-
*/
80-
template <typename V>
81-
inline std::complex<fvar<V>> pow(const std::complex<fvar<V>>& x,
82-
const std::complex<fvar<V>>& y) {
83-
return internal::complex_pow(x, y);
84-
}
8579

86-
/**
87-
* Return the first argument raised to the power of the second argument.
88-
*
89-
* @tparam V autodiff value type
90-
* @tparam T arithmetic type
91-
* @param x first argument
92-
* @param y second argument
93-
* @return first argument to the power of the second argument
94-
*/
95-
template <typename V, typename T, typename = require_arithmetic_t<T>>
96-
inline std::complex<fvar<V>> pow(const std::complex<fvar<V>>& x,
97-
const std::complex<T>& y) {
98-
return internal::complex_pow(x, y);
99-
}
10080

101-
/**
102-
* Return the first argument raised to the power of the second argument.
103-
*
104-
* @tparam V autodiff value type
105-
* @param x first argument
106-
* @param y second argument
107-
* @return first argument to the power of the second argument
108-
*/
109-
template <typename V>
110-
inline std::complex<fvar<V>> pow(const std::complex<fvar<V>>& x,
111-
const fvar<V>& y) {
112-
return internal::complex_pow(x, y);
113-
}
11481

11582
/**
116-
* Return the first argument raised to the power of the second argument.
117-
*
118-
* @tparam V autodiff value type
119-
* @tparam T arithmetic type
120-
* @param x first argument
121-
* @param y second argument
122-
* @return first argument to the power of the second argument
123-
*/
124-
template <typename V, typename T, typename = require_arithmetic_t<T>>
125-
inline std::complex<fvar<V>> pow(const std::complex<fvar<V>>& x, const T& y) {
126-
return internal::complex_pow(x, y);
127-
}
128-
129-
/**
130-
* Return the first argument raised to the power of the second argument.
131-
*
132-
* @tparam V autodiff value type
133-
* @tparam T arithmetic type
134-
* @param x first argument
135-
* @param y second argument
136-
* @return first argument to the power of the second argument
137-
*/
138-
template <typename V, typename T, typename = require_arithmetic_t<T>>
139-
inline std::complex<fvar<V>> pow(const std::complex<T>& x,
140-
const std::complex<fvar<V>>& y) {
141-
return internal::complex_pow(x, y);
142-
}
143-
144-
/**
145-
* Return the first argument raised to the power of the second argument.
146-
*
147-
* @tparam V autodiff value type
148-
* @tparam T arithmetic type
149-
* @param x first argument
150-
* @param y second argument
151-
* @return first argument to the power of the second argument
152-
*/
153-
template <typename V, typename T, typename = require_arithmetic_t<T>>
154-
inline std::complex<fvar<V>> pow(const std::complex<T>& x, const fvar<V>& y) {
155-
return internal::complex_pow(x, y);
156-
}
157-
158-
/**
159-
* Return the first argument raised to the power of the second argument.
160-
*
161-
* @tparam V autodiff value type
162-
* @param x first argument
163-
* @param y second argument
164-
* @return first argument to the power of the second argument
165-
*/
166-
template <typename V>
167-
inline std::complex<fvar<V>> pow(const fvar<V>& x,
168-
const std::complex<fvar<V>>& y) {
169-
return internal::complex_pow(x, y);
170-
}
171-
172-
/**
173-
* Return the first argument raised to the power of the second argument.
174-
*
175-
* @tparam V autodiff value type
176-
* @tparam T arithmetic type
177-
* @param x first argument
178-
* @param y second argument
179-
* @return first argument to the power of the second argument
180-
*/
181-
template <typename V, typename T, typename = require_arithmetic_t<T>>
182-
inline std::complex<fvar<V>> pow(const fvar<V>& x, const std::complex<T>& y) {
183-
return internal::complex_pow(x, y);
184-
}
185-
186-
/**
187-
* Return the first argument raised to the power of the second argument.
188-
*
189-
* @tparam V autodiff value type
190-
* @tparam T real type (`fvar<V>` or arithmetic)
191-
* @param x first argument
192-
* @param y second argument
193-
* @return first argument to the power of the second argument
194-
*/
195-
template <typename T, typename V, typename = require_arithmetic_t<T>>
196-
inline std::complex<fvar<V>> pow(const T& x, const std::complex<fvar<V>>& y) {
197-
return internal::complex_pow(x, y);
198-
}
199-
200-
/**
201-
* Return the first argument raised to the power of the second argument.
202-
*
203-
* Note: this overload is required because gcc still provides the
204-
* C++99 template function `pow(complex<T>, int)`, which introduces
205-
* an ambiguity.
83+
* Returns the elementwise raising of the first argument to the power of the
84+
* second argument.
20685
*
207-
* @tparam T autodiff value type
208-
* @param x first argument
209-
* @param y second argument
210-
* @return first argument to the power of the second argument
86+
* @tparam T1 type of first argument
87+
* @tparam T2 type of second argument
88+
* @param a first argument
89+
* @param b second argument
90+
* @return the elementwise raising of the first argument to the power of the
91+
* second argument.
21192
*/
212-
template <typename T>
213-
inline std::complex<fvar<T>> pow(const std::complex<fvar<T>>& x, int y) {
214-
return internal::complex_pow(x, y);
93+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
94+
require_all_not_matrix_st<is_var, T1, T2>* = nullptr,
95+
require_any_fvar_t<base_type_t<T1>, base_type_t<T2>>* = nullptr>
96+
inline auto pow(const T1& a, const T2& b) {
97+
return apply_scalar_binary(
98+
a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); });
21599
}
216100

217101
} // namespace math

stan/math/mix.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
#include <stan/math/mix/fun.hpp>
66
#include <stan/math/mix/functor.hpp>
77

8+
#include <stan/math/rev/constraint.hpp>
9+
#include <stan/math/rev/core.hpp>
10+
#include <stan/math/rev/meta.hpp>
11+
#include <stan/math/rev/fun.hpp>
12+
#include <stan/math/rev/functor.hpp>
13+
#include <stan/math/rev/prob.hpp>
14+
815
#include <stan/math/fwd/constraint.hpp>
916
#include <stan/math/fwd/core.hpp>
1017
#include <stan/math/fwd/meta.hpp>
@@ -17,13 +24,6 @@
1724
#include <stan/math/opencl/rev_constraint.hpp>
1825
#endif
1926

20-
#include <stan/math/rev/constraint.hpp>
21-
#include <stan/math/rev/core.hpp>
22-
#include <stan/math/rev/meta.hpp>
23-
#include <stan/math/rev/fun.hpp>
24-
#include <stan/math/rev/functor.hpp>
25-
#include <stan/math/rev/prob.hpp>
26-
2727
#include <stan/math/prim.hpp>
2828

2929
#endif

stan/math/prim/fun/pow.hpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/square.hpp>
67
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
78
#include <cmath>
89
#include <complex>
@@ -40,13 +41,28 @@ inline complex_return_t<U, V> complex_pow(const U& x, const V& y) {
4041
* argument.
4142
*/
4243
template <typename T1, typename T2,
43-
require_all_t<
44-
disjunction<is_complex<T1>, std::is_arithmetic<T1>>,
45-
disjunction<is_complex<T2>, std::is_arithmetic<T2>>>* = nullptr>
46-
inline auto pow(const T1& a, const T2& b) {
44+
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
45+
inline auto pow(const std::complex<T1>& a, const std::complex<T2>& b) {
4746
return std::pow(a, b);
4847
}
4948

49+
template <typename T1, typename T2,
50+
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
51+
inline auto pow(const T1& a, const std::complex<T2>& b) {
52+
return std::pow(a, b);
53+
}
54+
55+
template <typename T1, typename T2,
56+
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
57+
inline auto pow(const std::complex<T1>& a, const T2& b) {
58+
return std::pow(a, b);
59+
}
60+
61+
template <typename T1, typename T2,
62+
require_arithmetic_t<T1>* = nullptr, require_arithmetic_t<T2>* = nullptr>
63+
inline auto pow(const T1& a, const T2& b) {
64+
return std::pow(a, b);
65+
}
5066
/**
5167
* Returns the elementwise raising of the first argument to the power of the
5268
* second argument.
@@ -60,7 +76,7 @@ inline auto pow(const T1& a, const T2& b) {
6076
*/
6177
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
6278
require_all_not_matrix_st<is_var, T1, T2>* = nullptr,
63-
require_all_st_arithmetic<T1, T2>* = nullptr>
79+
require_all_arithmetic_t<base_type_t<T1>, base_type_t<T2>>* = nullptr>
6480
inline auto pow(const T1& a, const T2& b) {
6581
return apply_scalar_binary(
6682
a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); });

stan/math/prim/meta/base_type.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta/is_complex.hpp>
66
#include <stan/math/prim/meta/is_eigen.hpp>
7-
#include <stan/math/prim/meta/value_type.hpp>
87
#include <stan/math/prim/meta/is_vector.hpp>
98
#include <type_traits>
109
#include <vector>

stan/math/prim/meta/is_var.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef STAN_MATH_PRIM_META_IS_VAR_HPP
22
#define STAN_MATH_PRIM_META_IS_VAR_HPP
33

4-
#include <stan/math/prim/meta/require_helpers.hpp>
4+
#include <stan/math/prim/meta.hpp>
55

66
#include <type_traits>
77

0 commit comments

Comments
 (0)