Skip to content

Commit 56f71d7

Browse files
committed
Fixes for failing test
1 parent 6c6063d commit 56f71d7

1 file changed

Lines changed: 21 additions & 6 deletions

File tree

stan/math/prim/fun/poisson_binomial_log_probs.hpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,24 @@ namespace math {
2424
*/
2525
template <typename T_theta, typename T_scalar = scalar_type_t<T_theta>,
2626
require_vector_t<T_theta>* = nullptr>
27-
Eigen::Matrix<T_scalar, 1, -1>
27+
Eigen::Matrix<T_scalar, -1, 1>
2828
poisson_binomial_log_probs(int y, const T_theta& theta) {
2929
int size_theta = theta.size();
3030
plain_type_t<T_theta> log_theta = log(theta);
3131
plain_type_t<T_theta> log1m_theta = log1m(theta);
3232

33-
Eigen::Matrix<T_scalar, Eigen::Dynamic, Eigen::Dynamic> alpha(size_theta + 1,
34-
y + 1);
33+
Eigen::Matrix<T_scalar, Eigen::Dynamic, Eigen::Dynamic> alpha(y + 1,
34+
size_theta + 1);
3535

3636
// alpha[i, j] = log prob of j successes in first i trials
3737
alpha(0, 0) = 0.0;
3838
for (int i = 0; i < size_theta; ++i) {
3939
// no success in i trials
40-
alpha(i + 1, 0) = alpha(i, 0) + log1m_theta[i];
40+
alpha(0, i + 1) = alpha(0, i) + log1m_theta[i];
4141

4242
// 0 < j < i successes in i trials
4343
for (int j = 0; j < std::min(y, i); ++j) {
44-
alpha(i + 1, j + 1) = log_mix(theta[i], alpha(i, j), alpha(i, j + 1));
44+
alpha(j + 1, i + 1) = log_mix(theta[i], alpha(j, i), alpha(j + 1, i));
4545
}
4646

4747
// i successes in i trials
@@ -50,7 +50,22 @@ Eigen::Matrix<T_scalar, 1, -1>
5050
}
5151
}
5252

53-
return alpha.row(size_theta);
53+
return alpha.col(size_theta);
54+
}
55+
56+
template <typename T_y, typename T_theta, require_vt_integral<T_y>* = nullptr>
57+
auto poisson_binomial_log_probs(const T_y& y, const T_theta& theta) {
58+
using T_scalar = scalar_type_t<T_theta>;
59+
size_t max_sizes = std::max(stan::math::size(y), size_mvt(theta));
60+
std::vector<Eigen::Matrix<T_scalar, Eigen::Dynamic, 1>> result(max_sizes);
61+
scalar_seq_view<T_y> y_vec(y);
62+
vector_seq_view<T_theta> theta_vec(theta);
63+
64+
for (size_t i = 0; i < max_sizes; ++i) {
65+
result[i] = poisson_binomial_log_probs(y_vec[i], theta_vec[i]);
66+
}
67+
68+
return result;
5469
}
5570

5671
} // namespace math

0 commit comments

Comments
 (0)