Skip to content

Commit 80d22fb

Browse files
authored
Merge pull request #3095 from stan-dev/feature/stochastic-matrix-err-check
adds check functions for stochastic row and column matrices
2 parents 9052db8 + 72c34c0 commit 80d22fb

5 files changed

Lines changed: 497 additions & 0 deletions

File tree

stan/math/prim/err.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
#include <stan/math/prim/err/check_size_match.hpp>
4040
#include <stan/math/prim/err/check_square.hpp>
4141
#include <stan/math/prim/err/check_std_vector_index.hpp>
42+
#include <stan/math/prim/err/check_stochastic_column.hpp>
43+
#include <stan/math/prim/err/check_stochastic_row.hpp>
4244
#include <stan/math/prim/err/check_symmetric.hpp>
4345
#include <stan/math/prim/err/check_unit_vector.hpp>
4446
#include <stan/math/prim/err/check_vector.hpp>
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_COLUMN_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_COLUMN_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/err/check_nonzero_size.hpp>
7+
#include <stan/math/prim/err/constraint_tolerance.hpp>
8+
#include <stan/math/prim/err/make_iter_name.hpp>
9+
#include <stan/math/prim/err/throw_domain_error.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/fun/value_of_rec.hpp>
12+
#include <sstream>
13+
#include <string>
14+
#include <iostream>
15+
namespace stan {
16+
namespace math {
17+
18+
/**
19+
* Throw an exception if the specified matrix is not a column stochastic matrix.
20+
* To be a column stochastic matrix, all the values in each column must be
21+
* greater than or equal to 0 and the values must sum to 1. A valid column
22+
* stochastic matrix is one where the sum of the elements by column is equal
23+
* to 1. This function tests that the sum is within the tolerance specified by
24+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
25+
* typed vectors, not general matrices with 1 column.
26+
* @tparam T A type inheriting from `Eigen::EigenBase`
27+
* @param function Function name (for error messages)
28+
* @param name Variable name (for error messages)
29+
* @param theta Matrix to test
30+
* @throw `std::invalid_argument` if `theta` is a 0-vector
31+
* @throw `std::domain_error` if the vector is not a column stochastic matrix or
32+
* if any element is `NaN`
33+
*/
34+
template <typename T, require_matrix_t<T>* = nullptr>
35+
void check_stochastic_column(const char* function, const char* name,
36+
const T& theta) {
37+
using std::fabs;
38+
check_nonzero_size(function, name, theta);
39+
auto&& theta_ref = to_ref(value_of_rec(theta));
40+
for (Eigen::Index j = 0; j < theta_ref.cols(); ++j) {
41+
value_type_t<decltype(theta_ref)> vec_sum = 0.0;
42+
for (Eigen::Index i = 0; i < theta_ref.rows(); ++i) {
43+
if (!(theta_ref.coeff(i, j) >= 0)) {
44+
[&]() STAN_COLD_PATH {
45+
std::ostringstream msg;
46+
msg << "is not a valid column stochastic matrix. " << name << "["
47+
<< std::to_string(i + stan::error_index::value) << ", "
48+
<< std::to_string(i + stan::error_index::value) << "]"
49+
<< " = ";
50+
std::string msg_str(msg.str());
51+
throw_domain_error(function, name, theta_ref.coeff(i, j),
52+
msg_str.c_str(),
53+
", but should be greater than or equal to 0");
54+
}();
55+
}
56+
vec_sum += theta_ref.coeff(i, j);
57+
}
58+
if (!(fabs(1.0 - vec_sum) <= CONSTRAINT_TOLERANCE)) {
59+
[&]() STAN_COLD_PATH {
60+
std::stringstream msg;
61+
msg << "is not a valid column stochastic matrix.";
62+
msg.precision(10);
63+
msg << " sum(" << name << "[:, " << std::to_string(j + 1)
64+
<< "]) = " << vec_sum << ", but should be ";
65+
std::string msg_str(msg.str());
66+
throw_domain_error(function, name, 1.0, msg_str.c_str());
67+
}();
68+
}
69+
}
70+
}
71+
72+
/**
73+
* Throw an exception if the specified matrices in a standard vector are not a
74+
* column stochastic matrix. To be a column stochastic matrix, all the values in
75+
* each column must be greater than or equal to 0 and the values must sum to 1.
76+
* A valid column stochastic matrix is one where the sum of the elements by
77+
* column is equal to 1. This function tests that the sum is within the
78+
* tolerance specified by `CONSTRAINT_TOLERANCE`. This function only accepts
79+
* Eigen matrices, statically typed vectors, not general matrices with 1 column.
80+
* @tparam T A type inheriting from `Eigen::EigenBase`
81+
* @param function Function name (for error messages)
82+
* @param name Variable name (for error messages)
83+
* @param theta Matrix to test
84+
* @throw `std::invalid_argument` if `theta` is a 0-vector
85+
* @throw `std::domain_error` if the vector's matrices are not column stochastic
86+
* matrices or if any element is `NaN`
87+
*/
88+
template <typename T, require_std_vector_t<T>* = nullptr>
89+
void check_stochastic_column(const char* function, const char* name,
90+
const T& theta) {
91+
for (size_t i = 0; i < theta.size(); ++i) {
92+
check_stochastic_column(function, internal::make_iter_name(name, i).c_str(),
93+
theta[i]);
94+
}
95+
}
96+
97+
} // namespace math
98+
} // namespace stan
99+
#endif
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_ROW_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_ROW_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/err/check_nonzero_size.hpp>
7+
#include <stan/math/prim/err/constraint_tolerance.hpp>
8+
#include <stan/math/prim/err/make_iter_name.hpp>
9+
#include <stan/math/prim/err/throw_domain_error.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/fun/value_of_rec.hpp>
12+
#include <sstream>
13+
#include <string>
14+
15+
namespace stan {
16+
namespace math {
17+
18+
/**
19+
* Throw an exception if the specified matrix is not a row stochastic matrix. To
20+
* be a row stochastic matrix, all the values in each row must be greater than
21+
* or equal to 0 and the values must sum to 1. A valid row stochastic matrix is
22+
* one where the sum of the elements by row is equal to 1. This function tests
23+
* that the sum is within the tolerance specified by `CONSTRAINT_TOLERANCE`.
24+
* This function only accepts Eigen matrices, statically typed vectors, not
25+
* general matrices with 1 column.
26+
* @tparam T A type inheriting from `Eigen::EigenBase`
27+
* @param function Function name (for error messages)
28+
* @param name Variable name (for error messages)
29+
* @param theta Matrix to test
30+
* @throw `std::invalid_argument` if `theta` is a 0-vector
31+
* @throw `std::domain_error` if the vector is not a row stochastic matrix or if
32+
* any element is `NaN`
33+
*/
34+
template <typename T, require_matrix_t<T>* = nullptr>
35+
void check_stochastic_row(const char* function, const char* name,
36+
const T& theta) {
37+
using std::fabs;
38+
check_nonzero_size(function, name, theta);
39+
auto&& theta_ref = to_ref(value_of_rec(theta));
40+
for (Eigen::Index i = 0; i < theta_ref.rows(); ++i) {
41+
value_type_t<decltype(theta_ref)> vec_sum = 0.0;
42+
for (Eigen::Index j = 0; j < theta_ref.cols(); ++j) {
43+
if (!(theta_ref.coeff(i, j) >= 0)) {
44+
[&]() STAN_COLD_PATH {
45+
std::ostringstream msg;
46+
msg << "is not a valid row stochastic matrix. " << name << "["
47+
<< std::to_string(i + stan::error_index::value) << ", "
48+
<< std::to_string(i + stan::error_index::value) << "]"
49+
<< " = ";
50+
std::string msg_str(msg.str());
51+
throw_domain_error(function, name, theta_ref.coeff(i, j),
52+
msg_str.c_str(),
53+
", but should be greater than or equal to 0");
54+
}();
55+
}
56+
vec_sum += theta_ref.coeff(i, j);
57+
}
58+
if (!(fabs(1.0 - vec_sum) <= CONSTRAINT_TOLERANCE)) {
59+
[&]() STAN_COLD_PATH {
60+
std::stringstream msg;
61+
msg << "is not a valid row stochastic matrix.";
62+
msg.precision(10);
63+
msg << " sum(" << name << "[" << std::to_string(i + 1)
64+
<< ",:]) = " << vec_sum << ", but should be ";
65+
std::string msg_str(msg.str());
66+
throw_domain_error(function, name, 1.0, msg_str.c_str());
67+
}();
68+
}
69+
}
70+
}
71+
72+
/**
73+
* Throw an exception if the specified matrices in a standard vector are not a
74+
* row stochastic matrix. To be a row stochastic matrix, all the values in each
75+
* row must be greater than or equal to 0 and the values must sum to 1. A valid
76+
* row stochastic matrix is one where the sum of the elements by row is equal
77+
* to 1. This function tests that the sum is within the tolerance specified by
78+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
79+
* typed vectors, not general matrices with 1 column.
80+
* @tparam T A type inheriting from `Eigen::EigenBase`
81+
* @param function Function name (for error messages)
82+
* @param name Variable name (for error messages)
83+
* @param theta Matrix to test
84+
* @throw `std::invalid_argument` if `theta` is a 0-vector
85+
* @throw `std::domain_error` if the standard vector's matrices are not row
86+
* stochastic matrix or if any element is `NaN`
87+
*/
88+
template <typename T, require_std_vector_t<T>* = nullptr>
89+
void check_stochastic_row(const char* function, const char* name,
90+
const T& theta) {
91+
for (size_t i = 0; i < theta.size(); ++i) {
92+
check_stochastic_row(function, internal::make_iter_name(name, i).c_str(),
93+
theta[i]);
94+
}
95+
}
96+
97+
} // namespace math
98+
} // namespace stan
99+
#endif
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
#include <test/unit/util.hpp>
4+
#include <limits>
5+
#include <string>
6+
7+
TEST(ErrorHandlingMatrix, checkStochasticColumn) {
8+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(2, 2);
9+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{
10+
y_vec, y_vec, y_vec};
11+
for (auto& y_i : y) {
12+
y_i << 0.5, 0.5, 0.5, 0.5;
13+
}
14+
15+
EXPECT_NO_THROW(
16+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y));
17+
18+
for (auto& y_i : y) {
19+
y_i(0, 1) = 0.55;
20+
}
21+
EXPECT_THROW(
22+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
23+
std::domain_error);
24+
}
25+
26+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_negative_value) {
27+
std::string message;
28+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(3, 3);
29+
y_vec.setZero();
30+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{
31+
y_vec, y_vec, y_vec};
32+
for (auto& y_i : y) {
33+
y_i(0, 0) = -0.1;
34+
y_i(1, 0) = 1.1;
35+
y_i(0, 1) = -0.1;
36+
y_i(1, 1) = 1.1;
37+
}
38+
39+
try {
40+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
41+
FAIL() << "should have thrown";
42+
} catch (std::domain_error& e) {
43+
message = e.what();
44+
} catch (...) {
45+
FAIL() << "threw the wrong error";
46+
}
47+
48+
EXPECT_TRUE(std::string::npos
49+
!= message.find(" y[1] is not a valid column stochastic matrix"))
50+
<< "Found: " << message;
51+
52+
EXPECT_TRUE(std::string::npos != message.find("y[1][1, 1] = -0.1"))
53+
<< "Found: " << message;
54+
55+
for (auto& y_i : y) {
56+
y_i.setZero();
57+
y_i(0, 0) = 0.1;
58+
y_i(1, 0) = 0.1;
59+
y_i(2, 0) = 1.0;
60+
}
61+
try {
62+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
63+
FAIL() << "should have thrown";
64+
} catch (std::domain_error& e) {
65+
message = e.what();
66+
} catch (...) {
67+
FAIL() << "threw the wrong error";
68+
}
69+
70+
EXPECT_TRUE(std::string::npos
71+
!= message.find(" y[1] is not a valid column stochastic matrix"))
72+
<< "Found: " << message;
73+
74+
EXPECT_TRUE(std::string::npos != message.find("sum(y[1][:, 1]) = 1.2"))
75+
<< "Found: " << message;
76+
}
77+
78+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_sum) {
79+
std::string message;
80+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(10, 10);
81+
y_vec.setZero();
82+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{
83+
y_vec, y_vec, y_vec};
84+
for (auto& y_i : y) {
85+
y_i(3, 0) = 0.9;
86+
}
87+
88+
try {
89+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
90+
FAIL() << "should have thrown";
91+
} catch (std::domain_error& e) {
92+
message = e.what();
93+
} catch (...) {
94+
FAIL() << "threw the wrong error";
95+
}
96+
97+
EXPECT_TRUE(std::string::npos
98+
!= message.find(" y[1] is not a valid column stochastic matrix"))
99+
<< message;
100+
101+
EXPECT_TRUE(std::string::npos != message.find("sum(y[1][:, 1]) = 0.9"))
102+
<< message;
103+
}
104+
105+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_length) {
106+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(0, 0);
107+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{
108+
y_vec, y_vec, y_vec};
109+
110+
using stan::math::check_stochastic_column;
111+
112+
EXPECT_THROW_MSG(check_stochastic_column("checkStochasticColumn", "y", y),
113+
std::invalid_argument,
114+
"y[1] has size 0, but must have a non-zero size");
115+
}
116+
117+
TEST(ErrorHandlingMatrix, checkStochasticColumn_nan) {
118+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(2, 2);
119+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{
120+
y_vec, y_vec, y_vec};
121+
constexpr double nan = std::numeric_limits<double>::quiet_NaN();
122+
for (auto& y_i : y) {
123+
y_i << nan, 0.5, nan, 0.5;
124+
}
125+
126+
EXPECT_THROW(
127+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
128+
std::domain_error);
129+
130+
for (auto& y_i : y) {
131+
y_i(0, 1) = 0.55;
132+
}
133+
EXPECT_THROW(
134+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
135+
std::domain_error);
136+
137+
for (auto& y_i : y) {
138+
y_i(0, 0) = 0.5;
139+
y_i(0, 1) = nan;
140+
}
141+
EXPECT_THROW(
142+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
143+
std::domain_error);
144+
145+
for (auto& y_i : y) {
146+
y_i(0, 0) = nan;
147+
}
148+
EXPECT_THROW(
149+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
150+
std::domain_error);
151+
}

0 commit comments

Comments
 (0)