|
1 | 1 | #include <gtest/gtest.h> |
2 | | -#include <mpi.h> |
3 | | -#include <omp.h> |
4 | | - |
5 | | -#include <cstdio> |
6 | | -#include <cstdlib> |
7 | | -#include <format> |
8 | | -#include <iostream> |
9 | | -#include <memory> |
10 | | -#include <string> |
11 | | -#include <utility> |
12 | 2 |
|
| 3 | +#include "core/runners/include/runners.hpp" |
13 | 4 | #include "core/util/include/util.hpp" |
14 | 5 | #include "oneapi/tbb/global_control.h" |
15 | 6 |
|
16 | | -class UnreadMessagesDetector : public ::testing::EmptyTestEventListener { |
17 | | - public: |
18 | | - UnreadMessagesDetector() = default; |
19 | | - |
20 | | - void OnTestEnd(const ::testing::TestInfo& /*test_info*/) override { |
21 | | - int rank = -1; |
22 | | - MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
23 | | - |
24 | | - MPI_Barrier(MPI_COMM_WORLD); |
25 | | - |
26 | | - int flag = -1; |
27 | | - MPI_Status status; |
28 | | - |
29 | | - MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, &status); |
30 | | - |
31 | | - if (flag != 0) { |
32 | | - std::cerr << std::format( |
33 | | - "[ PROCESS {} ] [ FAILED ] {}.{}: MPI message queue has an unread message from process {} " |
34 | | - "with tag {}", |
35 | | - rank, "test_suite_name", "test_name", status.MPI_SOURCE, status.MPI_TAG) |
36 | | - << '\n'; |
37 | | - MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); |
38 | | - } |
39 | | - |
40 | | - MPI_Barrier(MPI_COMM_WORLD); |
41 | | - } |
42 | | - |
43 | | - private: |
44 | | -}; |
45 | | - |
46 | | -class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener { |
47 | | - public: |
48 | | - explicit WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base) : base_(std::move(base)) {} |
49 | | - |
50 | | - void OnTestEnd(const ::testing::TestInfo& test_info) override { |
51 | | - if (test_info.result()->Passed()) { |
52 | | - return; |
53 | | - } |
54 | | - PrintProcessRank(); |
55 | | - base_->OnTestEnd(test_info); |
56 | | - } |
57 | | - |
58 | | - void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override { |
59 | | - if (test_part_result.passed() || test_part_result.skipped()) { |
60 | | - return; |
61 | | - } |
62 | | - PrintProcessRank(); |
63 | | - base_->OnTestPartResult(test_part_result); |
64 | | - } |
65 | | - |
66 | | - private: |
67 | | - static void PrintProcessRank() { |
68 | | - int rank = -1; |
69 | | - MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
70 | | - std::cerr << std::format(" [ PROCESS {} ] ", rank); |
71 | | - } |
72 | | - |
73 | | - std::shared_ptr<::testing::TestEventListener> base_; |
74 | | -}; |
75 | | - |
76 | 7 | int main(int argc, char** argv) { |
77 | 8 | if (ppc::util::IsUnderMpirun()) { |
78 | | - MPI_Init(&argc, &argv); |
79 | | - |
80 | | - // Limit the number of threads in TBB |
81 | | - tbb::global_control control(tbb::global_control::max_allowed_parallelism, ppc::util::GetNumThreads()); |
82 | | - |
83 | | - ::testing::InitGoogleTest(&argc, argv); |
84 | | - |
85 | | - auto& listeners = ::testing::UnitTest::GetInstance()->listeners(); |
86 | | - int rank = -1; |
87 | | - MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
88 | | - if (rank != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) { |
89 | | - auto* listener = listeners.Release(listeners.default_result_printer()); |
90 | | - listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener))); |
91 | | - } |
92 | | - listeners.Append(new UnreadMessagesDetector()); |
93 | | - auto status = RUN_ALL_TESTS(); |
94 | | - |
95 | | - MPI_Finalize(); |
96 | | - return status; |
| 9 | + return ppc::core::Init(argc, argv); |
97 | 10 | } |
98 | 11 |
|
99 | 12 | // Limit the number of threads in TBB |
|
0 commit comments