|
1 | 1 | #pragma once |
2 | 2 |
|
3 | 3 | #include <gtest/gtest.h> |
4 | | -#include <mpi.h> |
5 | 4 | #include <omp.h> |
6 | 5 | #include <tbb/tick_count.h> |
7 | 6 |
|
| 7 | +#include <concepts> |
8 | 8 | #include <csignal> |
9 | | -#include <filesystem> |
10 | | -#include <fstream> |
| 9 | +#include <functional> |
| 10 | +#include <string> |
| 11 | +#include <tuple> |
11 | 12 |
|
12 | | -#include "core/perf/include/perf.hpp" |
| 13 | +#include "core/task/include/task.hpp" |
13 | 14 | #include "core/util/include/util.hpp" |
14 | 15 |
|
15 | 16 | namespace ppc::util { |
@@ -44,31 +45,53 @@ class BaseRunFuncTests : public ::testing::TestWithParam<FuncTestParam<InType, O |
44 | 45 | return std::get<GTestParamIndex::kNameTest>(info.param) + "_" + Derived::PrintTestParam(test_param); |
45 | 46 | } |
46 | 47 |
|
47 | | - protected: |
48 | 48 | void ExecuteTest(FuncTestParam<InType, OutType, TestType> test_param) { |
49 | | - ASSERT_FALSE(std::get<GTestParamIndex::kNameTest>(test_param).find("unknown") != std::string::npos); |
50 | | - if (std::get<GTestParamIndex::kNameTest>(test_param).find("disabled") != std::string::npos) { |
| 49 | + const std::string& test_name = std::get<GTestParamIndex::kNameTest>(test_param); |
| 50 | + |
| 51 | + validateTestName(test_name); |
| 52 | + |
| 53 | + if (isTestDisabled(test_name)) { |
51 | 54 | GTEST_SKIP(); |
52 | 55 | } |
53 | 56 |
|
54 | | - auto which_task = [&](const std::string& substring) { |
55 | | - return std::get<GTestParamIndex::kNameTest>(test_param).find(substring) != std::string::npos; |
56 | | - }; |
57 | | - |
58 | | - if (!ppc::util::IsUnderMpirun() && (which_task("_all") || which_task("_mpi"))) { |
59 | | - std::cerr << "kALL and kMPI tasks are not under mpirun" << '\n'; |
| 57 | + if (shouldSkipNonMpiTask(test_name)) { |
| 58 | + std::cerr << "kALL and kMPI tasks are not under mpirun\n"; |
60 | 59 | GTEST_SKIP(); |
61 | 60 | } |
62 | 61 |
|
63 | | - task_ = std::get<GTestParamIndex::kTaskGetter>(test_param)(GetTestInputData()); |
64 | | - ASSERT_TRUE(task_->Validation()); |
65 | | - ASSERT_TRUE(task_->PreProcessing()); |
66 | | - ASSERT_TRUE(task_->Run()); |
67 | | - ASSERT_TRUE(task_->PostProcessing()); |
68 | | - ASSERT_TRUE(CheckTestOutputData(task_->GetOutput())); |
| 62 | + initializeAndRunTask(test_param); |
69 | 63 | } |
70 | 64 |
|
71 | 65 | private: |
| 66 | + static constexpr std::string UNKNOWN_TEST = "unknown"; |
| 67 | + static constexpr std::string DISABLED_TEST = "disabled"; |
| 68 | + static constexpr std::string ALL_TASK = "_all"; |
| 69 | + static constexpr std::string MPI_TASK = "_mpi"; |
| 70 | + |
| 71 | + void validateTestName(const std::string& test_name) { |
| 72 | + EXPECT_FALSE(test_name.find(UNKNOWN_TEST) != std::string::npos); |
| 73 | + } |
| 74 | + |
| 75 | + bool isTestDisabled(const std::string& test_name) { return test_name.find(DISABLED_TEST) != std::string::npos; } |
| 76 | + |
| 77 | + bool shouldSkipNonMpiTask(const std::string& test_name) { |
| 78 | + auto containsSubstring = [&](const std::string& substring) { |
| 79 | + return test_name.find(substring) != std::string::npos; |
| 80 | + }; |
| 81 | + |
| 82 | + return !ppc::util::IsUnderMpirun() && (containsSubstring(ALL_TASK) || containsSubstring(MPI_TASK)); |
| 83 | + } |
| 84 | + |
| 85 | + void initializeAndRunTask(const FuncTestParam<InType, OutType, TestType>& test_param) { |
| 86 | + task_ = std::get<GTestParamIndex::kTaskGetter>(test_param)(GetTestInputData()); |
| 87 | + |
| 88 | + EXPECT_TRUE(task_->Validation()); |
| 89 | + EXPECT_TRUE(task_->PreProcessing()); |
| 90 | + EXPECT_TRUE(task_->Run()); |
| 91 | + EXPECT_TRUE(task_->PostProcessing()); |
| 92 | + EXPECT_TRUE(CheckTestOutputData(task_->GetOutput())); |
| 93 | + } |
| 94 | + |
72 | 95 | ppc::core::TaskPtr<InType, OutType> task_; |
73 | 96 | }; |
74 | 97 |
|
@@ -98,10 +121,6 @@ auto ExpandToValues(const Tuple& t) { |
98 | 121 | return GenTaskTuplesImpl<Task>(std::make_index_sequence<SizesParam.size()>{}); \ |
99 | 122 | } |
100 | 123 |
|
101 | | -#if 1 |
102 | | -#define ADD_FUNC_TASK(TASK) TaskListGenerator<TASK>() |
103 | | -#else |
104 | | -#define ADD_FUNC_TASK(TASK) std::tuple<>() |
105 | | -#endif |
| 124 | +#define ADD_FUNC_TASK(TASK) TaskListGenerator<TASK>() // std::tuple<>() |
106 | 125 |
|
107 | 126 | } // namespace ppc::util |
0 commit comments