Skip to content

Commit 3f4a904

Browse files
committed
adds check functions for stochastic row and column matrices
1 parent 9052db8 commit 3f4a904

5 files changed

Lines changed: 456 additions & 0 deletions

File tree

stan/math/prim/err.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
#include <stan/math/prim/err/check_size_match.hpp>
4040
#include <stan/math/prim/err/check_square.hpp>
4141
#include <stan/math/prim/err/check_std_vector_index.hpp>
42+
#include <stan/math/prim/err/check_stochastic_column.hpp>
43+
#include <stan/math/prim/err/check_stochastic_row.hpp>
4244
#include <stan/math/prim/err/check_symmetric.hpp>
4345
#include <stan/math/prim/err/check_unit_vector.hpp>
4446
#include <stan/math/prim/err/check_vector.hpp>
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_COLUMN_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_COLUMN_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/err/check_nonzero_size.hpp>
7+
#include <stan/math/prim/err/constraint_tolerance.hpp>
8+
#include <stan/math/prim/err/make_iter_name.hpp>
9+
#include <stan/math/prim/err/throw_domain_error.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/fun/value_of_rec.hpp>
12+
#include <sstream>
13+
#include <string>
14+
#include <iostream>
15+
namespace stan {
16+
namespace math {
17+
18+
/**
19+
* Throw an exception if the specified matrix is not a column stochastic matrix. To be a
20+
* column stochastic matrix, all the values in each column must be greater than or equal to 0 and the values must sum to 1. A
21+
* valid column stochastic matrix is one where the sum of the elements by column is equal to 1. This
22+
* function tests that the sum is within the tolerance specified by
23+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
24+
* typed vectors, not general matrices with 1 column.
25+
* @tparam T A type inheriting from `Eigen::EigenBase`
26+
* @param function Function name (for error messages)
27+
* @param name Variable name (for error messages)
28+
* @param theta Matrix to test
29+
* @throw `std::invalid_argument` if `theta` is a 0-vector
30+
* @throw `std::domain_error` if the vector is not a column stochastic matrix or if any element
31+
* is `NaN`
32+
*/
33+
template <typename T, require_matrix_t<T>* = nullptr>
34+
void check_stochastic_column(const char* function, const char* name, const T& theta) {
35+
using std::fabs;
36+
check_nonzero_size(function, name, theta);
37+
auto&& theta_ref = to_ref(value_of_rec(theta));
38+
for (Eigen::Index j = 0; j < theta_ref.cols(); ++j) {
39+
value_type_t<decltype(theta_ref)> vec_sum = 0.0;
40+
for (Eigen::Index i = 0; i < theta_ref.rows(); ++i) {
41+
if (!(theta_ref.coeff(i, j) >= 0)) {
42+
[&]() STAN_COLD_PATH {
43+
std::ostringstream msg;
44+
msg << "is not a valid column stochastic matrix. " << name << "["
45+
<< std::to_string(i + stan::error_index::value) <<
46+
", " << std::to_string(i + stan::error_index::value) << "]"
47+
<< " = ";
48+
std::string msg_str(msg.str());
49+
throw_domain_error(function, name, theta_ref.coeff(i, j), msg_str.c_str(),
50+
", but should be greater than or equal to 0");
51+
}();
52+
}
53+
vec_sum += theta_ref.coeff(i, j);
54+
}
55+
if (!(fabs(1.0 - vec_sum) <= CONSTRAINT_TOLERANCE)) {
56+
[&]() STAN_COLD_PATH {
57+
std::stringstream msg;
58+
msg << "is not a valid column stochastic matrix.";
59+
msg.precision(10);
60+
msg << " sum(" << name << "[:, "<< std::to_string(j + 1) << "]) = " << vec_sum << ", but should be ";
61+
std::string msg_str(msg.str());
62+
throw_domain_error(function, name, 1.0, msg_str.c_str());
63+
}();
64+
}
65+
}
66+
}
67+
68+
/**
69+
* Throw an exception if the specified matrices in a standard vector are not a column stochastic matrix. To be a
70+
* column stochastic matrix, all the values in each column must be greater than or equal to 0 and the values must sum to 1. A
71+
* valid column stochastic matrix is one where the sum of the elements by column is equal to 1. This
72+
* function tests that the sum is within the tolerance specified by
73+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
74+
* typed vectors, not general matrices with 1 column.
75+
* @tparam T A type inheriting from `Eigen::EigenBase`
76+
* @param function Function name (for error messages)
77+
* @param name Variable name (for error messages)
78+
* @param theta Matrix to test
79+
* @throw `std::invalid_argument` if `theta` is a 0-vector
80+
* @throw `std::domain_error` if the vector's matrices are not column stochastic matrices or if any element
81+
* is `NaN`
82+
*/
83+
template <typename T, require_std_vector_t<T>* = nullptr>
84+
void check_stochastic_column(const char* function, const char* name, const T& theta) {
85+
for (size_t i = 0; i < theta.size(); ++i) {
86+
check_stochastic_column(function, internal::make_iter_name(name, i).c_str(),
87+
theta[i]);
88+
}
89+
}
90+
91+
} // namespace math
92+
} // namespace stan
93+
#endif
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_ROW_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_STOCHASTIC_ROW_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/err/check_nonzero_size.hpp>
7+
#include <stan/math/prim/err/constraint_tolerance.hpp>
8+
#include <stan/math/prim/err/make_iter_name.hpp>
9+
#include <stan/math/prim/err/throw_domain_error.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/fun/value_of_rec.hpp>
12+
#include <sstream>
13+
#include <string>
14+
15+
namespace stan {
16+
namespace math {
17+
18+
/**
19+
* Throw an exception if the specified matrix is not a row stochastic matrix. To be a
20+
* row stochastic matrix, all the values in each row must be greater than or equal to 0 and the values must sum to 1. A
21+
* valid row stochastic matrix is one where the sum of the elements by row is equal to 1. This
22+
* function tests that the sum is within the tolerance specified by
23+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
24+
* typed vectors, not general matrices with 1 column.
25+
* @tparam T A type inheriting from `Eigen::EigenBase`
26+
* @param function Function name (for error messages)
27+
* @param name Variable name (for error messages)
28+
* @param theta Matrix to test
29+
* @throw `std::invalid_argument` if `theta` is a 0-vector
30+
* @throw `std::domain_error` if the vector is not a row stochastic matrix or if any element
31+
* is `NaN`
32+
*/
33+
template <typename T, require_matrix_t<T>* = nullptr>
34+
void check_stochastic_row(const char* function, const char* name, const T& theta) {
35+
using std::fabs;
36+
check_nonzero_size(function, name, theta);
37+
auto&& theta_ref = to_ref(value_of_rec(theta));
38+
for (Eigen::Index i = 0; i < theta_ref.rows(); ++i) {
39+
value_type_t<decltype(theta_ref)> vec_sum = 0.0;
40+
for (Eigen::Index j = 0; j < theta_ref.cols(); ++j) {
41+
if (!(theta_ref.coeff(i, j) >= 0)) {
42+
[&]() STAN_COLD_PATH {
43+
std::ostringstream msg;
44+
msg << "is not a valid row stochastic matrix. " << name << "["
45+
<< std::to_string(i + stan::error_index::value) <<
46+
", " << std::to_string(i + stan::error_index::value) << "]"
47+
<< " = ";
48+
std::string msg_str(msg.str());
49+
throw_domain_error(function, name, theta_ref.coeff(i, j), msg_str.c_str(),
50+
", but should be greater than or equal to 0");
51+
}();
52+
}
53+
vec_sum += theta_ref.coeff(i, j);
54+
}
55+
if (!(fabs(1.0 - vec_sum) <= CONSTRAINT_TOLERANCE)) {
56+
[&]() STAN_COLD_PATH {
57+
std::stringstream msg;
58+
msg << "is not a valid row stochastic matrix.";
59+
msg.precision(10);
60+
msg << " sum(" << name << "[" << std::to_string(i + 1) << ",:]) = " << vec_sum << ", but should be ";
61+
std::string msg_str(msg.str());
62+
throw_domain_error(function, name, 1.0, msg_str.c_str());
63+
}();
64+
}
65+
}
66+
}
67+
68+
/**
69+
* Throw an exception if the specified matrices in a standard vector are not a row stochastic matrix. To be a
70+
* row stochastic matrix, all the values in each row must be greater than or equal to 0 and the values must sum to 1. A
71+
* valid row stochastic matrix is one where the sum of the elements by row is equal to 1. This
72+
* function tests that the sum is within the tolerance specified by
73+
* `CONSTRAINT_TOLERANCE`. This function only accepts Eigen matrices, statically
74+
* typed vectors, not general matrices with 1 column.
75+
* @tparam T A type inheriting from `Eigen::EigenBase`
76+
* @param function Function name (for error messages)
77+
* @param name Variable name (for error messages)
78+
* @param theta Matrix to test
79+
* @throw `std::invalid_argument` if `theta` is a 0-vector
80+
* @throw `std::domain_error` if the standard vector's matrices are not row stochastic matrix or if any element
81+
* is `NaN`
82+
*/
83+
template <typename T, require_std_vector_t<T>* = nullptr>
84+
void check_stochastic_row(const char* function, const char* name, const T& theta) {
85+
for (size_t i = 0; i < theta.size(); ++i) {
86+
check_stochastic_row(function, internal::make_iter_name(name, i).c_str(),
87+
theta[i]);
88+
}
89+
}
90+
91+
} // namespace math
92+
} // namespace stan
93+
#endif
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
#include <test/unit/util.hpp>
4+
#include <limits>
5+
#include <string>
6+
7+
TEST(ErrorHandlingMatrix, checkStochasticColumn) {
8+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(2, 2);
9+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{y_vec, y_vec, y_vec};
10+
for (auto& y_i : y) {
11+
y_i << 0.5, 0.5, 0.5, 0.5;
12+
}
13+
14+
EXPECT_NO_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y));
15+
16+
for (auto& y_i : y) {
17+
y_i(0, 1) = 0.55;
18+
}
19+
EXPECT_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
20+
std::domain_error);
21+
}
22+
23+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_negative_value) {
24+
std::string message;
25+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(3, 3);
26+
y_vec.setZero();
27+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{y_vec, y_vec, y_vec};
28+
for (auto& y_i : y) {
29+
y_i(0, 0) = -0.1;
30+
y_i(1, 0) = 1.1;
31+
y_i(0, 1) = -0.1;
32+
y_i(1, 1) = 1.1;
33+
}
34+
35+
try {
36+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
37+
FAIL() << "should have thrown";
38+
} catch (std::domain_error& e) {
39+
message = e.what();
40+
} catch (...) {
41+
FAIL() << "threw the wrong error";
42+
}
43+
44+
EXPECT_TRUE(std::string::npos != message.find(" y[1] is not a valid column stochastic matrix"))
45+
<< "Found: " << message;
46+
47+
EXPECT_TRUE(std::string::npos != message.find("y[1][1, 1] = -0.1")) << "Found: " << message;
48+
49+
for (auto& y_i : y) {
50+
y_i.setZero();
51+
y_i(0, 0) = 0.1;
52+
y_i(1, 0) = 0.1;
53+
y_i(2, 0) = 1.0;
54+
}
55+
try {
56+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
57+
FAIL() << "should have thrown";
58+
} catch (std::domain_error& e) {
59+
message = e.what();
60+
} catch (...) {
61+
FAIL() << "threw the wrong error";
62+
}
63+
64+
EXPECT_TRUE(std::string::npos != message.find(" y[1] is not a valid column stochastic matrix"))
65+
<< "Found: " << message;
66+
67+
EXPECT_TRUE(std::string::npos != message.find("sum(y[1][:, 1]) = 1.2")) << "Found: " << message;
68+
}
69+
70+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_sum) {
71+
std::string message;
72+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(10, 10);
73+
y_vec.setZero();
74+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{y_vec, y_vec, y_vec};
75+
for (auto& y_i : y) {
76+
y_i(3, 0) = 0.9;
77+
}
78+
79+
try {
80+
stan::math::check_stochastic_column("checkStochasticColumn", "y", y);
81+
FAIL() << "should have thrown";
82+
} catch (std::domain_error& e) {
83+
message = e.what();
84+
} catch (...) {
85+
FAIL() << "threw the wrong error";
86+
}
87+
88+
EXPECT_TRUE(std::string::npos != message.find(" y[1] is not a valid column stochastic matrix"))
89+
<< message;
90+
91+
EXPECT_TRUE(std::string::npos != message.find("sum(y[1][:, 1]) = 0.9")) << message;
92+
}
93+
94+
TEST(ErrorHandlingMatrix, checkStochasticColumn_message_length) {
95+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(0, 0);
96+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{y_vec, y_vec, y_vec};
97+
98+
using stan::math::check_stochastic_column;
99+
100+
EXPECT_THROW_MSG(check_stochastic_column("checkStochasticColumn", "y", y), std::invalid_argument,
101+
"y[1] has size 0, but must have a non-zero size");
102+
}
103+
104+
TEST(ErrorHandlingMatrix, checkStochasticColumn_nan) {
105+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> y_vec(2, 2);
106+
std::vector<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>> y{y_vec, y_vec, y_vec};
107+
constexpr double nan = std::numeric_limits<double>::quiet_NaN();
108+
for (auto& y_i : y) {
109+
y_i << nan, 0.5, nan, 0.5;
110+
}
111+
112+
EXPECT_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
113+
std::domain_error);
114+
115+
for (auto& y_i : y) {
116+
y_i(0, 1) = 0.55;
117+
}
118+
EXPECT_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
119+
std::domain_error);
120+
121+
for (auto& y_i : y) {
122+
y_i(0, 0) = 0.5;
123+
y_i(0, 1) = nan;
124+
}
125+
EXPECT_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
126+
std::domain_error);
127+
128+
for (auto& y_i : y) {
129+
y_i(0, 0) = nan;
130+
}
131+
EXPECT_THROW(stan::math::check_stochastic_column("checkStochasticColumn", "y", y),
132+
std::domain_error);
133+
}
134+

0 commit comments

Comments
 (0)