Skip to content

Commit 9d07107

Browse files
committed
Refactor functional and performance runners by extracting MPI-related logic into ppc::core::Init for improved modularity.
1 parent 2604844 commit 9d07107

4 files changed

Lines changed: 119 additions & 184 deletions

File tree

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
#include <gtest/gtest.h>
4+
5+
namespace ppc::core {
6+
7+
class UnreadMessagesDetector : public ::testing::EmptyTestEventListener {
8+
public:
9+
UnreadMessagesDetector() = default;
10+
void OnTestEnd(const ::testing::TestInfo& /*test_info*/) override;
11+
12+
private:
13+
};
14+
15+
class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
16+
public:
17+
explicit WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base) : base_(std::move(base)) {}
18+
void OnTestEnd(const ::testing::TestInfo& test_info) override;
19+
void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override;
20+
21+
private:
22+
static void PrintProcessRank();
23+
std::shared_ptr<::testing::TestEventListener> base_;
24+
};
25+
26+
int Init(int argc, char** argv);
27+
28+
} // namespace ppc::core
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include "core/runners/include/runners.hpp"
2+
3+
#include <gtest/gtest.h>
4+
#include <mpi.h>
5+
#include <omp.h>
6+
7+
#include <cstdio>
8+
#include <cstdlib>
9+
#include <format>
10+
#include <iostream>
11+
#include <memory>
12+
#include <string>
13+
#include <utility>
14+
15+
#include "core/util/include/util.hpp"
16+
#include "oneapi/tbb/global_control.h"
17+
18+
namespace ppc::core {
19+
20+
void UnreadMessagesDetector::OnTestEnd(const ::testing::TestInfo& /*test_info*/) {
21+
int rank = -1;
22+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
23+
24+
MPI_Barrier(MPI_COMM_WORLD);
25+
26+
int flag = -1;
27+
MPI_Status status;
28+
29+
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
30+
31+
if (flag != 0) {
32+
std::cerr
33+
<< std::format(
34+
"[ PROCESS {} ] [ FAILED ] MPI message queue has an unread message from process {} with tag {}",
35+
rank, status.MPI_SOURCE, status.MPI_TAG)
36+
<< '\n';
37+
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
38+
}
39+
40+
MPI_Barrier(MPI_COMM_WORLD);
41+
}
42+
43+
void WorkerTestFailurePrinter::OnTestEnd(const ::testing::TestInfo& test_info) {
44+
if (test_info.result()->Passed()) {
45+
return;
46+
}
47+
PrintProcessRank();
48+
base_->OnTestEnd(test_info);
49+
}
50+
51+
void WorkerTestFailurePrinter::OnTestPartResult(const ::testing::TestPartResult& test_part_result) {
52+
if (test_part_result.passed() || test_part_result.skipped()) {
53+
return;
54+
}
55+
PrintProcessRank();
56+
base_->OnTestPartResult(test_part_result);
57+
}
58+
59+
void WorkerTestFailurePrinter::PrintProcessRank() {
60+
int rank = -1;
61+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
62+
std::cerr << std::format(" [ PROCESS {} ] ", rank);
63+
}
64+
65+
int Init(int argc, char** argv) {
66+
MPI_Init(&argc, &argv);
67+
68+
// Limit the number of threads in TBB
69+
tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetNumThreads());
70+
71+
::testing::InitGoogleTest(&argc, argv);
72+
73+
auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
74+
int rank = -1;
75+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
76+
if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
77+
auto* listener = listeners.Release(listeners.default_result_printer());
78+
listeners.Append(new ppc::core::WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
79+
}
80+
listeners.Append(new ppc::core::UnreadMessagesDetector());
81+
auto status = RUN_ALL_TESTS();
82+
83+
MPI_Finalize();
84+
return status;
85+
}
86+
87+
} // namespace ppc::core

tasks/common/runners/functional.cpp

Lines changed: 2 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,12 @@
11
#include <gtest/gtest.h>
2-
#include <mpi.h>
3-
#include <omp.h>
4-
5-
#include <cstdio>
6-
#include <cstdlib>
7-
#include <format>
8-
#include <iostream>
9-
#include <memory>
10-
#include <string>
11-
#include <utility>
122

3+
#include "core/runners/include/runners.hpp"
134
#include "core/util/include/util.hpp"
145
#include "oneapi/tbb/global_control.h"
156

16-
class UnreadMessagesDetector : public ::testing::EmptyTestEventListener {
17-
public:
18-
UnreadMessagesDetector() = default;
19-
20-
void OnTestEnd(const ::testing::TestInfo& /*test_info*/) override {
21-
int rank = -1;
22-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
23-
24-
MPI_Barrier(MPI_COMM_WORLD);
25-
26-
int flag = -1;
27-
MPI_Status status;
28-
29-
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
30-
31-
if (flag != 0) {
32-
std::cerr << std::format(
33-
"[ PROCESS {} ] [ FAILED ] {}.{}: MPI message queue has an unread message from process {} "
34-
"with tag {}",
35-
rank, "test_suite_name", "test_name", status.MPI_SOURCE, status.MPI_TAG)
36-
<< '\n';
37-
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
38-
}
39-
40-
MPI_Barrier(MPI_COMM_WORLD);
41-
}
42-
43-
private:
44-
};
45-
46-
class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
47-
public:
48-
explicit WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base) : base_(std::move(base)) {}
49-
50-
void OnTestEnd(const ::testing::TestInfo& test_info) override {
51-
if (test_info.result()->Passed()) {
52-
return;
53-
}
54-
PrintProcessRank();
55-
base_->OnTestEnd(test_info);
56-
}
57-
58-
void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override {
59-
if (test_part_result.passed() || test_part_result.skipped()) {
60-
return;
61-
}
62-
PrintProcessRank();
63-
base_->OnTestPartResult(test_part_result);
64-
}
65-
66-
private:
67-
static void PrintProcessRank() {
68-
int rank = -1;
69-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
70-
std::cerr << std::format(" [ PROCESS {} ] ", rank);
71-
}
72-
73-
std::shared_ptr<::testing::TestEventListener> base_;
74-
};
75-
767
int main(int argc, char** argv) {
778
if (ppc::util::IsUnderMpirun()) {
78-
MPI_Init(&argc, &argv);
79-
80-
// Limit the number of threads in TBB
81-
tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetNumThreads());
82-
83-
::testing::InitGoogleTest(&argc, argv);
84-
85-
auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
86-
int rank = -1;
87-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
88-
if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
89-
auto* listener = listeners.Release(listeners.default_result_printer());
90-
listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
91-
}
92-
listeners.Append(new UnreadMessagesDetector());
93-
auto status = RUN_ALL_TESTS();
94-
95-
MPI_Finalize();
96-
return status;
9+
return ppc::core::Init(argc, argv);
9710
}
9811

9912
// Limit the number of threads in TBB
Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,3 @@
1-
#include <gtest/gtest.h>
2-
#include <mpi.h>
3-
#include <omp.h>
1+
#include "core/runners/include/runners.hpp"
42

5-
#include <cstdio>
6-
#include <cstdlib>
7-
#include <format>
8-
#include <iostream>
9-
#include <memory>
10-
#include <string>
11-
#include <utility>
12-
13-
#include "core/util/include/util.hpp"
14-
#include "oneapi/tbb/global_control.h"
15-
16-
class UnreadMessagesDetector : public ::testing::EmptyTestEventListener {
17-
public:
18-
UnreadMessagesDetector() = default;
19-
20-
void OnTestEnd(const ::testing::TestInfo& /*test_info*/) override {
21-
int rank = -1;
22-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
23-
24-
MPI_Barrier(MPI_COMM_WORLD);
25-
26-
int flag = -1;
27-
MPI_Status status;
28-
29-
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
30-
31-
if (flag != 0) {
32-
std::cerr
33-
<< std::format(
34-
"[ PROCESS {} ] [ FAILED ] MPI message queue has an unread message from process {} with tag {}",
35-
rank, status.MPI_SOURCE, status.MPI_TAG)
36-
<< '\n';
37-
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
38-
}
39-
40-
MPI_Barrier(MPI_COMM_WORLD);
41-
}
42-
43-
private:
44-
};
45-
46-
class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
47-
public:
48-
explicit WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base) : base_(std::move(base)) {}
49-
50-
void OnTestEnd(const ::testing::TestInfo& test_info) override {
51-
if (test_info.result()->Passed()) {
52-
return;
53-
}
54-
PrintProcessRank();
55-
base_->OnTestEnd(test_info);
56-
}
57-
58-
void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override {
59-
if (test_part_result.passed() || test_part_result.skipped()) {
60-
return;
61-
}
62-
PrintProcessRank();
63-
base_->OnTestPartResult(test_part_result);
64-
}
65-
66-
private:
67-
static void PrintProcessRank() {
68-
int rank = -1;
69-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
70-
std::cerr << std::format(" [ PROCESS {} ] ", rank);
71-
}
72-
73-
std::shared_ptr<::testing::TestEventListener> base_;
74-
};
75-
76-
int main(int argc, char** argv) {
77-
MPI_Init(&argc, &argv);
78-
79-
// Limit the number of threads in TBB
80-
tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetNumThreads());
81-
82-
::testing::InitGoogleTest(&argc, argv);
83-
84-
auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
85-
int rank = -1;
86-
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
87-
if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) {
88-
auto* listener = listeners.Release(listeners.default_result_printer());
89-
listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
90-
}
91-
listeners.Append(new UnreadMessagesDetector());
92-
auto status = RUN_ALL_TESTS();
93-
94-
MPI_Finalize();
95-
return status;
96-
}
3+
int main(int argc, char** argv) { return ppc::core::Init(argc, argv); }

0 commit comments

Comments
 (0)