Skip to content

Commit 27116c0

Browse files
committed
Nested array support
1 parent 50f1313 commit 27116c0

2 files changed

Lines changed: 21 additions & 8 deletions

File tree

stan/math/prim/fun/size_mvt.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ int64_t size_mvt(const ScalarT& /* unused */) {
2626
throw std::invalid_argument("size_mvt passed to an unrecognized type.");
2727
}
2828

29+
template <typename ScalarT, require_stan_scalar_t<ScalarT>* = nullptr>
30+
int64_t size_mvt(const std::vector<ScalarT>& /* unused */) {
31+
return 1;
32+
}
33+
2934
template <typename MatrixT, require_matrix_t<MatrixT>* = nullptr>
3035
int64_t size_mvt(const MatrixT& /* unused */) {
3136
return 1;
@@ -36,6 +41,11 @@ int64_t size_mvt(const std::vector<MatrixT>& x) {
3641
return x.size();
3742
}
3843

44+
template <typename StdVectorT, require_std_vector_t<StdVectorT>* = nullptr>
45+
int64_t size_mvt(const std::vector<StdVectorT>& x) {
46+
return x.size();
47+
}
48+
3949
} // namespace math
4050
} // namespace stan
4151
#endif

stan/math/prim/fun/vector_seq_view.hpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
#include <vector>
77

88
namespace stan {
9+
namespace internal {
10+
template <typename T>
11+
using is_matrix_or_std_vector
12+
= math::disjunction<is_matrix<T>, is_std_vector<T>>;
13+
14+
template <typename T>
15+
using is_scalar_container = math::conjunction<is_matrix_or_std_vector<T>,
16+
is_stan_scalar<value_type_t<T>>>;
17+
}
918

1019
/**
1120
* This class provides a low-cost wrapper for situations where you either need
@@ -33,7 +42,7 @@ class vector_seq_view;
3342
* @tparam T the type of the underlying Vector
3443
*/
3544
template <typename T>
36-
class vector_seq_view<T, require_matrix_t<T>> {
45+
class vector_seq_view<T, require_t<internal::is_scalar_container<T>>> {
3746
public:
3847
explicit vector_seq_view(const T& m) : m_(m) {}
3948
static constexpr auto size() { return 1; }
@@ -46,19 +55,13 @@ class vector_seq_view<T, require_matrix_t<T>> {
4655

4756
template <typename C = T, require_st_autodiff<C>* = nullptr>
4857
inline auto val(size_t /* i */) const noexcept {
49-
return m_.val();
58+
return value_of(m_);
5059
}
5160

5261
private:
5362
const ref_type_t<T> m_;
5463
};
5564

56-
namespace internal {
57-
template <typename T>
58-
using is_matrix_or_std_vector
59-
= math::disjunction<is_matrix<T>, is_std_vector<T>>;
60-
}
61-
6265
/**
6366
* This class provides a low-cost wrapper for situations where you either need
6467
* an Eigen Vector or RowVector or a std::vector of them and you want to be

0 commit comments

Comments
 (0)