@@ -69,6 +69,34 @@ inline auto sum_to_zero_constrain(T&& y) {
6969 return arena_z;
7070}
7171
72+ /* *
73+ * Return a matrix that sums to zero over both the rows
74+ * and columns corresponding to the free matrix x.
75+ *
76+ * This is a linear transform, with no Jacobian.
77+ *
78+ * @tparam Mat type of the matrix
79+ * @param x Free matrix input of dimensionality (N - 1, M - 1).
80+ * @return Zero-sum matrix of dimensionality (N, M).
81+ */
82+ template <typename T, require_rev_matrix_t <T>* = nullptr ,
83+ require_not_t <is_rev_vector<T>>* = nullptr >
84+ inline auto sum_to_zero_constrain (T&& x) {
85+ using ret_type = plain_type_t <T>;
86+ if (unlikely (x.size () == 0 )) {
87+ return arena_t <ret_type>(Eigen::MatrixXd{{0 }});
88+ }
89+ auto arena_x = to_arena (std::forward<T>(x));
90+ arena_t <ret_type> arena_z = sum_to_zero_constrain (arena_x.val ());
91+
92+ reverse_pass_callback ([arena_x, arena_z]() mutable {
93+ const auto N = arena_x.rows ();
94+ const auto M = arena_x.cols ();
95+ });
96+
97+ return arena_z;
98+ }
99+
72100/* *
73101 * Return a vector with sum zero corresponding to the specified
74102 * free vector.
@@ -89,12 +117,12 @@ inline auto sum_to_zero_constrain(T&& y) {
89117 *
90118 * This is a linear transform, with no Jacobian.
91119 *
92- * @tparam Vec type of the vector
93- * @param y Free vector input of dimensionality K - 1 .
120+ * @tparam T type of the vector or matrix
121+ * @param y Free vector or matrix .
94122 * @param lp unused
95- * @return Zero-sum vector of dimensionality K.
123+ * @return Zero-sum vector or matrix which is one larger in each dimension
96124 */
97- template <typename T, typename Lp, require_rev_col_vector_t <T>* = nullptr >
125+ template <typename T, typename Lp, is_rev_matrix <T>* = nullptr >
98126inline auto sum_to_zero_constrain (T&& y, Lp& lp) {
99127 return sum_to_zero_constrain (std::forward<T>(y));
100128}
0 commit comments