forked from learning-process/parallel_programming_course
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrunner.cpp
More file actions
93 lines (73 loc) · 2.54 KB
/
runner.cpp
File metadata and controls
93 lines (73 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <gtest/gtest.h>
#include <mpi.h>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include "core/util/include/util.hpp"
#include "oneapi/tbb/global_control.h"
class UnreadMessagesDetector : public ::testing::EmptyTestEventListener {
public:
UnreadMessagesDetector() = default;
void OnTestEnd(const ::testing::TestInfo& /*test_info*/) override {
int rank = -1;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Barrier(MPI_COMM_WORLD);
int flag = -1;
MPI_Status status;
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status);
if (flag != 0) {
fprintf(
stderr,
"[ PROCESS %d ] [ FAILED ] %s.%s: MPI message queue has an unread message from process %d with tag %d\n",
rank, "test_suite_name", "test_name", status.MPI_SOURCE, status.MPI_TAG);
MPI_Finalize();
std::abort();
}
MPI_Barrier(MPI_COMM_WORLD);
}
private:
};
/// Result printer for non-root MPI ranks.
///
/// Wraps another gtest listener (the default printer in this file's `main`)
/// and forwards only non-passing events to it. Passing or skipped assertion
/// results and passing tests are suppressed. Every forwarded event is preceded
/// by a " [ PROCESS <rank> ] " prefix so the output can be attributed to a rank.
class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener {
 public:
  /// @param base  listener that actually renders forwarded events; ownership is shared.
  explicit WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base) : base_(std::move(base)) {}

  void OnTestEnd(const ::testing::TestInfo& test_info) override {
    const bool test_passed = test_info.result()->Passed();
    if (!test_passed) {
      EmitRankPrefix();
      base_->OnTestEnd(test_info);
    }
  }

  void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override {
    const bool nothing_to_report = test_part_result.passed() || test_part_result.skipped();
    if (!nothing_to_report) {
      EmitRankPrefix();
      base_->OnTestPartResult(test_part_result);
    }
  }

 private:
  /// Writes this process's MPI_COMM_WORLD rank to stdout as a line prefix.
  static void EmitRankPrefix() {
    int world_rank = -1;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    printf(" [ PROCESS %d ] ", world_rank);
  }

  std::shared_ptr<::testing::TestEventListener> base_;
};
/// Entry point of the MPI + TBB test runner.
///
/// Initializes MPI, caps TBB parallelism at ppc::util::GetPPCNumThreads() and
/// runs every registered GoogleTest test.
/// - On non-root ranks, the default gtest printer is wrapped in
///   WorkerTestFailurePrinter so workers report only failures. This does not
///   happen when `--print-workers` is present among the arguments left after
///   gtest parsing.
/// - On every rank, UnreadMessagesDetector checks for leaked MPI messages after
///   each test.
///
/// @return the RUN_ALL_TESTS() status of this rank.
int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  // Limit the number of threads in TBB
  tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetPPCNumThreads());

  // Per the gtest docs, flags gtest recognizes are removed from argv here.
  // What remains are the runner's own arguments.
  ::testing::InitGoogleTest(&argc, argv);

  // Accept --print-workers at any position among the remaining arguments,
  // not only as argv[1].
  bool print_workers = false;
  for (int i = 1; i < argc; ++i) {
    if (std::string(argv[i]) == "--print-workers") {
      print_workers = true;
      break;
    }
  }

  int rank = -1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
  if (rank != 0 && !print_workers) {
    // Take the default printer out of gtest's list and re-register it behind a
    // failure-only, rank-prefixing wrapper.
    auto* listener = listeners.Release(listeners.default_result_printer());
    listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener)));
  }
  // gtest's listener list takes ownership of appended listeners.
  listeners.Append(new UnreadMessagesDetector());

  auto status = RUN_ALL_TESTS();
  MPI_Finalize();
  return status;
}