Skip to content

Commit 7a9601d

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/reverse-mode-move-semantics
2 parents c5f983a + e73651b commit 7a9601d

17 files changed

Lines changed: 500 additions & 87 deletions

lib/tbb_2020.3/STAN_CHANGES

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ This file documents changes done for the stan-math project
66
- build/windows.inc patches for RTools make:
77
- L15 changed setting to use '?=', allowing override
88
- L25,L113,L114 added additional '/' to each cmd flag
9+
10+
- Support for Windows ARM64 with RTools:
11+
- build/Makefile.tbb
12+
- L94 Wrapped the use of `--version-script` export in conditional on non-WINARM64
13+
- build/windows.gcc.inc
14+
- L84 Wrapped the use of `-flifetime-dse` flag in conditional on non-WINARM64
15+

lib/tbb_2020.3/build/Makefile.tbb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ ifneq (,$(TBB.DEF))
9191
tbb.def: $(TBB.DEF) $(TBB.LST)
9292
$(CPLUS) $(PREPROC_ONLY) $< $(CPLUS_FLAGS) $(INCLUDES) > $@
9393

94-
LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
94+
# LLVM on Windows doesn't need --version-script export
95+
# https://reviews.llvm.org/D63743
96+
ifeq (, $(WINARM64))
97+
LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
98+
endif
9599
$(TBB.DLL): tbb.def
96100
endif
97101

lib/tbb_2020.3/build/windows.gcc.inc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ endif
8080
# gcc 6.0 and later have -flifetime-dse option that controls
8181
# elimination of stores done outside the object lifetime
8282
ifeq (ok,$(call detect_js,/minversion gcc 6.0))
83-
# keep pre-contruction stores for zero initialization
84-
DSE_KEY = -flifetime-dse=1
83+
# Clang does not support -flifetime-dse
84+
ifeq (, $(WINARM64))
85+
# keep pre-construction stores for zero initialization
86+
DSE_KEY = -flifetime-dse=1
87+
endif
8588
endif
8689

8790
ifeq ($(cfg), release)

make/compiler_flags

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ endif
1717

1818
## Set OS specific library filename extensions
1919
ifeq ($(OS),Windows_NT)
20+
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
2021
LIBRARY_SUFFIX ?= .dll
2122
endif
2223

@@ -271,8 +272,13 @@ CXXFLAGS_TBB ?= -I $(TBB_INC)
271272
else
272273
CXXFLAGS_TBB ?= -I $(TBB)/include
273274
endif
275+
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags
276+
277+
# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
278+
ifeq ($(WINARM64),)
279+
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
280+
endif
274281

275-
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,-rpath,"$(TBB_LIB)" -Wl,--disable-new-dtags
276282
LDLIBS_TBB ?= -ltbb
277283

278284
else
@@ -290,7 +296,12 @@ ifeq ($(OS),Linux)
290296
endif
291297

292298
CXXFLAGS_TBB ?= -I $(TBB)/include
293-
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
299+
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
300+
301+
# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
302+
ifeq ($(WINARM64),)
303+
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
304+
endif
294305
LDLIBS_TBB ?= -ltbb
295306

296307
endif

make/libraries

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ CPPLINT ?= $(MATH)lib/cpplint_1.4.5
2323
# Fortran bindings which we do not need for stan-math. Thus these targets
2424
# are ignored here. This convention was introduced with 4.0.
2525
##
26+
ifndef SUNDIALS_TARGETS
2627

2728
SUNDIALS_CVODES := $(patsubst %.c,%.o,\
2829
$(wildcard $(SUNDIALS)/src/cvodes/*.c) \
@@ -87,7 +88,7 @@ $(STAN_SUNDIALS_HEADERS) : $(SUNDIALS_TARGETS)
8788
clean-sundials:
8889
@echo ' cleaning sundials targets'
8990
$(RM) $(wildcard $(sort $(SUNDIALS_CVODES) $(SUNDIALS_IDAS) $(SUNDIALS_KINSOL) $(SUNDIALS_NVECSERIAL) $(SUNDIALS_TARGETS)))
90-
91+
endif
9192

9293
############################################################
9394
# TBB build rules
@@ -138,6 +139,11 @@ endif
138139
ifeq (Windows_NT, $(OS))
139140
ifeq ($(IS_UCRT),true)
140141
TBB_CXXFLAGS += -D_UCRT
142+
endif
143+
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
144+
ifneq ($(WINARM64),)
145+
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
146+
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
141147
endif
142148
SH_CHECK := $(shell command -v sh 2>/dev/null)
143149
ifdef SH_CHECK
@@ -169,11 +175,11 @@ endif
169175
$(TBB_BIN)/tbb.def: $(TBB_BIN)/tbb-make-check
170176
@mkdir -p $(TBB_BIN)
171177
touch $(TBB_BIN)/version_$(notdir $(TBB))
172-
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"
178+
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"
173179

174180
$(TBB_BIN)/tbbmalloc.def: $(TBB_BIN)/tbb-make-check
175181
@mkdir -p $(TBB_BIN)
176-
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"
182+
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"
177183

178184
$(TBB_BIN)/libtbb.dylib: $(TBB_BIN)/tbb.def
179185
$(TBB_BIN)/libtbbmalloc.dylib: $(TBB_BIN)/tbbmalloc.def

stan/math/prim/fun/value_of.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ inline auto value_of(const T& x) {
6767
* @param[in] M Matrix to be converted
6868
* @return Matrix of values
6969
**/
70-
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
70+
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
7171
require_not_st_arithmetic<EigMat>* = nullptr>
7272
inline auto value_of(EigMat&& M) {
7373
return make_holder(
@@ -77,6 +77,28 @@ inline auto value_of(EigMat&& M) {
7777
std::forward<EigMat>(M));
7878
}
7979

80+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
81+
require_not_st_arithmetic<EigMat>* = nullptr>
82+
inline auto value_of(EigMat&& M) {
83+
auto&& M_ref = to_ref(M);
84+
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
85+
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
86+
M_ref.cols());
87+
ret.reserve(M_ref.nonZeros());
88+
for (int k = 0; k < M_ref.outerSize(); ++k) {
89+
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
90+
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
91+
}
92+
}
93+
ret.makeCompressed();
94+
return ret;
95+
}
96+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
97+
require_st_arithmetic<EigMat>* = nullptr>
98+
inline auto value_of(EigMat&& M) {
99+
return std::forward<EigMat>(M);
100+
}
101+
80102
} // namespace math
81103
} // namespace stan
82104

stan/math/prim/meta/is_eigen_dense_base.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ using require_eigen_dense_base_t
3333
= require_t<is_eigen_dense_base<std::decay_t<T>>>;
3434
/*! @} */
3535

36+
/*! \ingroup require_eigens_types */
37+
/*! \defgroup eigen_dense_base_types eigen_dense_base_types */
38+
/*! \addtogroup eigen_dense_base_types */
39+
/*! @{ */
40+
41+
/*! \brief Require type satisfies @ref is_eigen_dense_base */
42+
/*! and value type satisfies `TypeCheck` */
43+
/*! @tparam TypeCheck The type trait to check the value type against */
44+
/*! @tparam Check The type to test @ref is_eigen_dense_base for and whose
45+
* @ref value_type is checked with `TypeCheck` */
46+
template <template <class...> class TypeCheck, class... Check>
47+
using require_eigen_dense_base_vt
48+
= require_t<container_type_check_base<is_eigen_dense_base, value_type_t,
49+
TypeCheck, Check...>>;
50+
/*! @} */
51+
3652
} // namespace stan
3753

3854
#endif

stan/math/prim/meta/promote_scalar_type.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta/is_eigen.hpp>
66
#include <stan/math/prim/meta/is_var.hpp>
7+
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
8+
#include <stan/math/prim/meta/is_eigen_sparse_base.hpp>
79
#include <vector>
810

911
namespace stan {
@@ -80,7 +82,7 @@ struct promote_scalar_type<T, S,
8082
* @tparam S input matrix type
8183
*/
8284
template <typename T, typename S>
83-
struct promote_scalar_type<T, S, require_eigen_t<S>> {
85+
struct promote_scalar_type<T, S, require_eigen_dense_base_t<S>> {
8486
/**
8587
* The promoted type.
8688
*/
@@ -93,6 +95,16 @@ struct promote_scalar_type<T, S, require_eigen_t<S>> {
9395
S::RowsAtCompileTime, S::ColsAtCompileTime>>::type;
9496
};
9597

98+
template <typename T, typename S>
99+
struct promote_scalar_type<T, S, require_eigen_sparse_base_t<S>> {
100+
/**
101+
* The promoted type.
102+
*/
103+
using type = Eigen::SparseMatrix<
104+
typename promote_scalar_type<T, typename S::Scalar>::type, S::Options,
105+
typename S::StorageIndex>;
106+
};
107+
96108
template <typename... PromotionScalars, typename... UnPromotedTypes>
97109
struct promote_scalar_type<std::tuple<PromotionScalars...>,
98110
std::tuple<UnPromotedTypes...>> {

stan/math/rev/core/arena_matrix.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include <stan/math/rev/core/chainable_object.hpp>
88
#include <stan/math/rev/core/var_value_fwd_declare.hpp>
99
#include <stan/math/prim/fun/to_ref.hpp>
10-
1110
namespace stan {
1211
namespace math {
1312

@@ -269,8 +268,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
269268
*/
270269
arena_matrix(const arena_matrix<MatrixType>& other)
271270
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
272-
other.outerIndexPtr(), other.innerIndexPtr(),
273-
other.valuePtr(), other.innernonZeroPtr()) {}
271+
const_cast<StorageIndex*>(other.outerIndexPtr()),
272+
const_cast<StorageIndex*>(other.innerIndexPtr()),
273+
const_cast<Scalar*>(other.valuePtr()),
274+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
274275
/**
275276
* Move constructor.
276277
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -279,8 +280,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
279280
*/
280281
arena_matrix(arena_matrix<MatrixType>&& other)
281282
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
282-
other.outerIndexPtr(), other.innerIndexPtr(),
283-
other.valuePtr(), other.innerNonZeroPtr()) {}
283+
const_cast<StorageIndex*>(other.outerIndexPtr()),
284+
const_cast<StorageIndex*>(other.innerIndexPtr()),
285+
const_cast<Scalar*>(other.valuePtr()),
286+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
284287
/**
285288
* Copy constructor. No actual copy is performed
286289
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -289,8 +292,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
289292
*/
290293
arena_matrix(arena_matrix<MatrixType>& other)
291294
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
292-
other.outerIndexPtr(), other.innerIndexPtr(),
293-
other.valuePtr(), other.innerNonZeroPtr()) {}
295+
const_cast<StorageIndex*>(other.outerIndexPtr()),
296+
const_cast<StorageIndex*>(other.innerIndexPtr()),
297+
const_cast<Scalar*>(other.valuePtr()),
298+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
294299

295300
// without this using, compiler prefers combination of implicit construction
296301
// and copy assignment to the inherited operator when assigned an expression
@@ -303,7 +308,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
303308
* @return `*this`
304309
*/
305310
template <typename ArenaMatrix,
306-
require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr>
311+
require_same_t<std::decay_t<ArenaMatrix>,
312+
arena_matrix<MatrixType>>* = nullptr>
307313
arena_matrix& operator=(ArenaMatrix&& other) {
308314
// placement new changes what data map points to - there is no allocation
309315
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
@@ -324,7 +330,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
324330
template <typename Expr,
325331
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
326332
arena_matrix& operator=(Expr&& expr) {
327-
*this = arena_matrix(std::forward<Expr>(expr));
333+
new (this) arena_matrix(std::forward<Expr>(expr));
328334
return *this;
329335
}
330336

stan/math/rev/core/var.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,18 @@ class var_value<T, internal::require_matrix_var_value<T>> {
418418
});
419419
}
420420

421+
/**
422+
* Construct a `var_value` with premade @ref arena_matrix types.
423+
* The values and adjoint matrices passed here will be shallow copied.
424+
* @tparam S type of the value in the `var_value` to assing
425+
* @param val The value matrix to go into the vari
426+
* @param adj the adjoint matrix to go into the vari
427+
*/
428+
template <typename S, typename T_ = T,
429+
require_assignable_t<value_type, S>* = nullptr,
430+
require_arena_matrix_t<S>* = nullptr>
431+
var_value(const S& val, const S& adj) : vi_(new vari_type(val, adj)) {}
432+
421433
/**
422434
* Construct a variable from a pointer to a variable implementation.
423435
* @param vi A vari_value pointer.

0 commit comments

Comments
 (0)