Skip to content

Commit 76d2a05

Browse files
authored
Merge pull request #2928 from stan-dev/feature/reverse-mode-move-semantics
Allow arena_matrix to use move semantics
2 parents 00314f3 + 91ea4c1 commit 76d2a05

72 files changed

Lines changed: 958 additions & 335 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

doxygen/contributor_help_pages/common_pitfalls.md

Lines changed: 138 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,6 @@ The implementation of @ref stan::math::make_holder is [here](https://github.com/
154154

155155
### Move Semantics
156156

157-
In general, Stan Math does not use move semantics very often.
158-
This is because of our arena allocator.
159157
Move semantics generally work as
160158

161159
```cpp
@@ -179,6 +177,96 @@ We can see in the above that the standard style of a move (the constructor takin
179177
But in Stan, particularly for reverse mode, we need to keep memory around even if it's only temporary for when we call the gradient calculations in the reverse pass.
180178
And since memory for reverse mode is stored in our arena allocator no copying happens in the first place.
181179
180+
Functions for Stan Math's reverse mode autodiff should use [_perfect forwarding_](https://drewcampbell92.medium.com/understanding-move-semantics-and-perfect-forwarding-part-3-65575d523ff8) arguments. Perfect forwarding arguments use a template parameter wit no attributes such as `const` and `volatile` and have a double ampersand `&&` next to them.
181+
182+
```c++
183+
template <typename T>
184+
auto my_function(T&& x) {
185+
return my_other_function(std::forward<T>(x));
186+
}
187+
```
188+
189+
The `std::forward<T>` in the in the code above tells the compiler that if `T` is deduced to be an rvalue reference (such as `Eigen::MatrixXd&&`), then it should be moved to `my_other_function`, where there it can possibly use another objects move constructor to reuse memory.
190+
A perfect forwarding argument of a function accepts any reference type as its input argument.
191+
The above signature is equivalent to writing out several functions with different reference types
192+
193+
```c++
194+
// Accepts a plain lvalue reference
195+
auto my_function(Eigen::MatrixXd& x) {
196+
return my_other_function(x);
197+
}
198+
// Accepts a const lvalue reference
199+
auto my_function(const Eigen::MatrixXd& x) {
200+
return my_other_function(x);
201+
}
202+
// Accepts an rvalue reference
203+
auto my_function(Eigen::MatrixXd&& x) {
204+
return my_other_function(std::move(x));
205+
}
206+
// Accepts a const rvalue reference
207+
auto my_function(const Eigen::MatrixXd&& x) {
208+
return my_other_function(std::move(x));
209+
}
210+
```
211+
212+
In Stan, perfect forwarding is used in reverse mode functions which can accept an Eigen matrix type.
213+
214+
```c++
215+
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
216+
inline auto sin(T&& x) {
217+
// Store `x` on the arena
218+
arena_t<T> x_arena(std::forward<T>(x));
219+
arena_t<T> ret(x_arena.val().array().sin().matrix());
220+
reverse_pass_callback([x_arena, ret] mutable {
221+
x_arena.adj() += ret.adj().cwiseProduct(x_arena.val().array().cos().matrix());
222+
});
223+
return ret;
224+
}
225+
```
226+
227+
Let's go through the above line by line.
228+
229+
```c++
230+
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
231+
inline auto sin(T&& x) {
232+
```
233+
234+
The signature for this function has a template `T` that is required to be an Eigen type with a `value_type` that is a `var` type.
235+
The template parameter `T` is then used in the signature as an perfect forwarding argument.
236+
237+
```c++
238+
// Store `x` on the arena
239+
arena_t<T> x_arena(std::forward<T>(x));
240+
```
241+
242+
The input is stored in the arena, which is where the perfect forwarding magic actually occurs.
243+
If `T` is an lvalue type such as `Eigen::MatrixXd&` then `arena_matrix` will use it's copy constructor, creating new memory in Stan's arena allocator and then copying the values of `x` into that memory.
244+
But if `T` was a temporary rvalue type such as `Eigen::MatrixXd&&`, then the `arena_matrix` class will use it's move constructor to place the temporary matrix in Stan's `var_alloc_stack_`.
245+
The `var_alloc_stick_` is used to hold objects that were created outside of the arena allocator but need to be deleted when the arena allocator is cleared.
246+
This allows the `arena_matrix` to reuse the memory from the temporary matrix. Then the matrix will be deleted once arena allocator requests memory to be cleared.
247+
248+
```c++
249+
arena_t<T> ret(x_arena.val().array().sin().matrix());
250+
```
251+
252+
This construction of an `arena_matrix` will *not* use the move constructor for `arena_matrix`.
253+
Here, `x_arena` is an `arena_matrix<T>`, which is then wrapped in an expression to compute the elementwise `sin`.
254+
That expression will be evaluated into new memory allocated in the arena allocator and then a pointer to it will be stored in the `arena_matrix.`
255+
256+
```c++
257+
reverse_pass_callback([x_arena, ret] mutable {
258+
x_arena.adj() += ret.adj().cwiseProduct(x_arena.val().array().cos().matrix());
259+
});
260+
return ret;
261+
```
262+
263+
The rest of this code follows the standard format for the rest of Stan Math's reverse mode that accepts Eigen types as input.
264+
The `reverse_pass_callback` function accepts a lambda as input and places the lambda in Stan's callback stack to be called later when `grad()` is called by the user.
265+
Since `arena_matrix` types only store a pointer to memory allocated elsewhere they are copied into the lambda.
266+
The body of the lambda holds the gradient calculation needed for the reverse mode pass.
267+
268+
Then finally `ret`, the `arena_matrix` type is returned by the function.
269+
182270
When working with arithmetic types, keep in mind that moving Scalars is often less optimal than simply taking their copy.
183271
For instance, Stan's `var` type uses the pointer to implementation (PIMPL) pattern, so it simply holds a pointer of size 8 bytes.
184272
A `double` is also 8 bytes which just so happens to fit exactly in a [word](https://en.wikipedia.org/wiki/Word_(computer_architecture)) of most modern CPUs with at least 64-byte cache lines.
@@ -190,6 +278,45 @@ The general rules to follow for passing values to a function are:
190278
2. If you are writing a function for reverse mode, pass values by `const&`
191279
3. In prim, if you are confident and working with larger types, use perfect forwarding to pass values that can be moved from. Otherwise simply pass values by `const&`.
192280

281+
### Using auto is Dangerous With Eigen Matrix Functions in Reverse Mode
282+
283+
The use of auto with the Stan Math library should be used with care, like in [Eigen](https://eigen.tuxfamily.org/dox/TopicPitfalls.html).
284+
Along with the cautions mentioned in the Eigen docs, there are also memory considerations when using reverse mode automatic differentiation.
285+
When returning from a function in the Stan Math library with an Eigen matrix output with a scalar `var` type, the actual returned type will often be an `arena_matrix<Eigen::Matrix<...>>`.
286+
The `arena_matrix` class is an Eigen matrix where the underlying array of memory is located in Stan's memory arena.
287+
The `arena_matrix` that is returned by Math functions is normally the same one resting in the callback used to calculate gradients in the reverse pass.
288+
Directly changing the elements of this matrix would also change the memory the reverse pass callback sees which would result in incorrect calculations.
289+
290+
The simple solution to this is that when you use a math library function that returns a matrix and then want to assign to any of the individual elements of the matrix, assign to an actual Eigen matrix type instead of using auto.
291+
In the below example, we see the first case which uses auto and will change the memory of the `arena_matrix` returned in the callback for multiply's reverse mode.
292+
Directly below it is the safe version, which just directly assigns to an Eigen matrix type and is safe to do element insertion into.
293+
294+
```c++
295+
Eigen::Matrix<var, -1, 1> y;
296+
Eigen::Matrix<var, -1, -1> X;
297+
// Bad!! Will change memory used by reverse pass callback within multiply!
298+
auto mu = multiply(X, y);
299+
mu(4) = 1.0;
300+
// Good! Will not change memory used by reverse pass callback within multiply
301+
Eigen::Matrix<var, -1, 1> mu_good = multiply(X, y);
302+
mu_good(4) = 1.0;
303+
```
304+
305+
The reason we do this is for cases where function returns are passed to other functions.
306+
An `arena_matrix` will always make a shallow copy when being constructed from another `arena_matrix`, which lets the functions avoid unnecessary copies.
307+
308+
```c++
309+
Eigen::Matrix<var, -1, 1> y1;
310+
Eigen::Matrix<var, -1, -1> X1;
311+
Eigen::Matrix<var, -1, 1> y2;
312+
Eigen::Matrix<var, -1, -1> X2;
313+
auto mu1 = multiply(X1, y1);
314+
auto mu2 = multiply(X2, y2);
315+
// Inputs not copied in this case!
316+
auto z = add(mu1, mu2);
317+
```
318+
319+
193320
### Passing variables that need destructors called after the reverse pass (`make_chainable_ptr`)
194321

195322
When possible, non-arena variables should be copied to the arena to be used in the reverse pass.
@@ -242,22 +369,17 @@ grad();
242369
```
243370

244371
Now `res` is `innocent_return` and we've changed one of the elements of `innocent_return`, but that is also going to change the element of `res` which is being used in our reverse pass callback!
245-
The answer for this is simple but sadly requires a copy.
246372

247-
```cpp
248-
template <typename EigVec, require_eigen_vt<is_var, EigVec>* = nullptr>
249-
inline var cool_fun(const EigVec& v) {
250-
arena_t<EigVec> arena_v(v);
251-
arena_t<EigVec> res = arena_v.val().array() * arena_v.val().array();
252-
reverse_pass_callback([res, arena_v]() mutable {
253-
arena_v.adj().array() += (2.0 * res.adj().array()) * arena_v.val().array();
254-
});
255-
return plain_type_t<EigVec>(res);
256-
}
257-
```
373+
Care must be taken by end users of Stan Math by using `auto` with caution.
374+
When a user wishes to manipulate the coefficients of a matrix that is a return from a function in Stan Math, they should assign the matrix to a plain Eigen type.
258375

259-
we make a deep copy of the return whose inner `vari` will not be the same, but the `var` will produce a new copy of the pointer to the `vari`.
260-
Now the user code above will be protected, and it is safe for them to assign to individual elements of the `auto` returned matrix.
376+
```c++
377+
Eigen::Matrix<var, -1, 1> x = Eigen::Matrix<double, -1, 1>::Random(5);
378+
Eigen::MatrixXd actually_innocent_return = cool_fun(x);
379+
actually_innocent_return.coeffRef(3) = var(3.0);
380+
auto still_unsafe_return = cool_fun2(actually_innocent_return);
381+
grad();
382+
```
261383

262384
### Const correctness, reverse mode autodiff, and arena types
263385

lib/tbb_2020.3/STAN_CHANGES

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ This file documents changes done for the stan-math project
66
- build/windows.inc patches for RTools make:
77
- L15 changed setting to use '?=', allowing override
88
- L25,L113,L114 added additional '/' to each cmd flag
9+
10+
- Support for Windows ARM64 with RTools:
11+
- build/Makefile.tbb
12+
- L94 Wrapped the use of `--version-script` export in conditional on non-WINARM64
13+
- build/windows.gcc.ino
14+
- L84 Wrapped the use of `-flifetime-dse` flag in conditional on non-WINARM64
15+

lib/tbb_2020.3/build/Makefile.tbb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ ifneq (,$(TBB.DEF))
9191
tbb.def: $(TBB.DEF) $(TBB.LST)
9292
$(CPLUS) $(PREPROC_ONLY) $< $(CPLUS_FLAGS) $(INCLUDES) > $@
9393

94-
LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
94+
# LLVM on Windows doesn't need --version-script export
95+
# https://reviews.llvm.org/D63743
96+
ifeq (, $(WINARM64))
97+
LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
98+
endif
9599
$(TBB.DLL): tbb.def
96100
endif
97101

lib/tbb_2020.3/build/windows.gcc.inc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ endif
8080
# gcc 6.0 and later have -flifetime-dse option that controls
8181
# elimination of stores done outside the object lifetime
8282
ifeq (ok,$(call detect_js,/minversion gcc 6.0))
83-
# keep pre-contruction stores for zero initialization
84-
DSE_KEY = -flifetime-dse=1
83+
# Clang does not support -flifetime-dse
84+
ifeq (, $(WINARM64))
85+
# keep pre-contruction stores for zero initialization
86+
DSE_KEY = -flifetime-dse=1
87+
endif
8588
endif
8689

8790
ifeq ($(cfg), release)

make/compiler_flags

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ endif
1717

1818
## Set OS specific library filename extensions
1919
ifeq ($(OS),Windows_NT)
20+
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
2021
LIBRARY_SUFFIX ?= .dll
2122
endif
2223

@@ -271,8 +272,13 @@ CXXFLAGS_TBB ?= -I $(TBB_INC)
271272
else
272273
CXXFLAGS_TBB ?= -I $(TBB)/include
273274
endif
275+
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags
276+
277+
# Windows LLVM/Clang does not support -rpath, but is not needed on Windows anyway
278+
ifeq ($(WINARM64),)
279+
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
280+
endif
274281

275-
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,-rpath,"$(TBB_LIB)" -Wl,--disable-new-dtags
276282
LDLIBS_TBB ?= -ltbb
277283

278284
else
@@ -290,7 +296,12 @@ ifeq ($(OS),Linux)
290296
endif
291297

292298
CXXFLAGS_TBB ?= -I $(TBB)/include
293-
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
299+
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
300+
301+
# Windows LLVM/Clang does not support -rpath, but is not needed on Windows anyway
302+
ifeq ($(WINARM64),)
303+
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
304+
endif
294305
LDLIBS_TBB ?= -ltbb
295306

296307
endif

make/libraries

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ CPPLINT ?= $(MATH)lib/cpplint_1.4.5
2323
# Fortran bindings which we do not need for stan-math. Thus these targets
2424
# are ignored here. This convention was introduced with 4.0.
2525
##
26+
ifndef SUNDIALS_TARGETS
2627

2728
SUNDIALS_CVODES := $(patsubst %.c,%.o,\
2829
$(wildcard $(SUNDIALS)/src/cvodes/*.c) \
@@ -87,7 +88,7 @@ $(STAN_SUNDIALS_HEADERS) : $(SUNDIALS_TARGETS)
8788
clean-sundials:
8889
@echo ' cleaning sundials targets'
8990
$(RM) $(wildcard $(sort $(SUNDIALS_CVODES) $(SUNDIALS_IDAS) $(SUNDIALS_KINSOL) $(SUNDIALS_NVECSERIAL) $(SUNDIALS_TARGETS)))
90-
91+
endif
9192

9293
############################################################
9394
# TBB build rules
@@ -138,6 +139,11 @@ endif
138139
ifeq (Windows_NT, $(OS))
139140
ifeq ($(IS_UCRT),true)
140141
TBB_CXXFLAGS += -D_UCRT
142+
endif
143+
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
144+
ifneq ($(WINARM64),)
145+
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
146+
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
141147
endif
142148
SH_CHECK := $(shell command -v sh 2>/dev/null)
143149
ifdef SH_CHECK
@@ -169,11 +175,11 @@ endif
169175
$(TBB_BIN)/tbb.def: $(TBB_BIN)/tbb-make-check
170176
@mkdir -p $(TBB_BIN)
171177
touch $(TBB_BIN)/version_$(notdir $(TBB))
172-
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"
178+
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"
173179

174180
$(TBB_BIN)/tbbmalloc.def: $(TBB_BIN)/tbb-make-check
175181
@mkdir -p $(TBB_BIN)
176-
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"
182+
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"
177183

178184
$(TBB_BIN)/libtbb.dylib: $(TBB_BIN)/tbb.def
179185
$(TBB_BIN)/libtbbmalloc.dylib: $(TBB_BIN)/tbbmalloc.def

stan/math/prim/fun/value_of.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ inline auto value_of(const T& x) {
6767
* @param[in] M Matrix to be converted
6868
* @return Matrix of values
6969
**/
70-
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
70+
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
7171
require_not_st_arithmetic<EigMat>* = nullptr>
7272
inline auto value_of(EigMat&& M) {
7373
return make_holder(
@@ -77,6 +77,28 @@ inline auto value_of(EigMat&& M) {
7777
std::forward<EigMat>(M));
7878
}
7979

80+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
81+
require_not_st_arithmetic<EigMat>* = nullptr>
82+
inline auto value_of(EigMat&& M) {
83+
auto&& M_ref = to_ref(M);
84+
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
85+
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
86+
M_ref.cols());
87+
ret.reserve(M_ref.nonZeros());
88+
for (int k = 0; k < M_ref.outerSize(); ++k) {
89+
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
90+
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
91+
}
92+
}
93+
ret.makeCompressed();
94+
return ret;
95+
}
96+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
97+
require_st_arithmetic<EigMat>* = nullptr>
98+
inline auto value_of(EigMat&& M) {
99+
return std::forward<EigMat>(M);
100+
}
101+
80102
} // namespace math
81103
} // namespace stan
82104

stan/math/prim/meta/is_arena_matrix.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,17 @@ template <typename T>
2323
using require_arena_matrix_t = require_t<is_arena_matrix<std::decay_t<T>>>;
2424
/*! @} */
2525

26+
/*! \ingroup require_eigen_types */
27+
/*! \defgroup arena_matrix_types arena_matrix */
28+
/*! \addtogroup arena_matrix_types */
29+
/*! @{ */
30+
31+
/*! \brief Require type does not satisfy @ref is_arena_matrix */
32+
/*! @tparam T the type to check */
33+
template <typename T>
34+
using require_not_arena_matrix_t
35+
= require_t<bool_constant<!is_arena_matrix<std::decay_t<T>>::value>>;
36+
/*! @} */
37+
2638
} // namespace stan
2739
#endif

stan/math/prim/meta/is_eigen_dense_base.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ using require_eigen_dense_base_t
3333
= require_t<is_eigen_dense_base<std::decay_t<T>>>;
3434
/*! @} */
3535

36+
/*! \ingroup require_eigens_types */
37+
/*! \defgroup eigen_dense_base_types eigen_dense_base_types */
38+
/*! \addtogroup eigen_dense_base_types */
39+
/*! @{ */
40+
41+
/*! \brief Require type satisfies @ref is_eigen_dense_base */
42+
/*! and value type satisfies `TypeCheck` */
43+
/*! @tparam TypeCheck The type trait to check the value type against */
44+
/*! @tparam Check The type to test @ref is_eigen_dense_base for and whose
45+
* @ref value_type is checked with `TypeCheck` */
46+
template <template <class...> class TypeCheck, class... Check>
47+
using require_eigen_dense_base_vt
48+
= require_t<container_type_check_base<is_eigen_dense_base, value_type_t,
49+
TypeCheck, Check...>>;
50+
/*! @} */
51+
3652
} // namespace stan
3753

3854
#endif

0 commit comments

Comments
 (0)