// func_test_util.hpp
// Forked from learning-process/parallel_programming_course.
#pragma once
#include <gtest/gtest.h>
#include <tbb/tick_count.h>
#include <concepts>
#include <csignal>
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include "task/include/task.hpp"
#include "util/include/util.hpp"
namespace ppc::util {
/// @brief Parameter tuple describing a single functional test case.
///
/// Elements, in declaration order: a factory that builds a task from its input data, a test name string, and a
/// user-defined test parameter. Code in this file reads the elements through GTestParamIndex
/// (kTaskGetter, kNameTest, kTestParams).
/// NOTE(review): with the default TestType = void the tuple would contain a void element and could not be
/// instantiated — presumably callers always pass an explicit TestType; confirm.
/// @tparam InType Type of the task input data.
/// @tparam OutType Type of the task output data.
/// @tparam TestType Type of the per-test parameter.
template <typename InType, typename OutType, typename TestType = void>
using FuncTestParam = std::tuple<std::function<ppc::task::TaskPtr<InType, OutType>(InType)>, std::string, TestType>;
/// @brief ::testing::TestParamInfo specialised for FuncTestParam.
///
/// This is the argument type of BaseRunFuncTests::PrintFuncTestName, which reads the tuple through its `param`
/// member.
/// @tparam InType Type of the task input data.
/// @tparam OutType Type of the task output data.
/// @tparam TestType Type of the per-test parameter.
template <typename InType, typename OutType, typename TestType = void>
using GTestFuncParam = ::testing::TestParamInfo<FuncTestParam<InType, OutType, TestType>>;
/// @brief Satisfied when `T::PrintTestParam(value)` is a valid call for a TestType lvalue `value` and the type of
/// that call expression is exactly std::string.
/// @tparam T Type expected to provide PrintTestParam.
/// @tparam TestType Type of the argument passed to PrintTestParam.
template <typename T, typename TestType>
concept HasPrintTestParam = requires(TestType value) {
// Compound requirement: the call must compile and its result type must be std::string (not merely convertible).
{ T::PrintTestParam(value) } -> std::same_as<std::string>;
};
/// @brief Base class for running functional tests on parallel tasks.
///
/// A derived fixture supplies the input data (GetTestInputData) and the output check (CheckTestOutputData).
/// ExecuteTest() is the protected entry point: it inspects the test name, optionally skips the test, builds the
/// task from the parameter tuple and drives it through Validation, PreProcessing, Run and PostProcessing.
/// @tparam InType Type of input data.
/// @tparam OutType Type of output data.
/// @tparam TestType Type of the test case or parameter.
template <typename InType, typename OutType, typename TestType = void>
class BaseRunFuncTests : public ::testing::TestWithParam<FuncTestParam<InType, OutType, TestType>> {
 public:
  /// @brief Checks the output produced by the task.
  /// @param output_data Output of the task, obtained through GetOutput() after PostProcessing() was called.
  /// @return true if the output is accepted; false is reported as a (non-fatal) test failure.
  virtual bool CheckTestOutputData(OutType& output_data) = 0;

  /// @brief Provides input data for the task.
  /// @return Initialized input data.
  virtual InType GetTestInputData() = 0;

  /// @brief Compile-time check that Derived provides the static interface needed for test naming.
  /// @tparam Derived Concrete fixture type that must satisfy HasPrintTestParam.
  template <typename Derived>
  static void RequireStaticInterface() {
    static_assert(HasPrintTestParam<Derived, TestType>,
                  "Derived class must implement: static std::string PrintTestParam(TestType)");
  }

  /// @brief Builds a test name of the form "<name element>_<Derived::PrintTestParam(parameter)>".
  /// @tparam Derived Concrete fixture type providing PrintTestParam.
  /// @param info GoogleTest parameter info of the test instance being named.
  /// @return The name element of the parameter tuple, an underscore, and the printed test parameter.
  template <typename Derived>
  static std::string PrintFuncTestName(const GTestFuncParam<InType, OutType, TestType>& info) {
    RequireStaticInterface<Derived>();
    // Kept as a mutable copy on purpose: HasPrintTestParam is checked against a non-const TestType lvalue, so a
    // PrintTestParam overload taking TestType& must keep working.
    TestType test_param = std::get<static_cast<std::size_t>(GTestParamIndex::kTestParams)>(info.param);
    return std::get<static_cast<std::size_t>(GTestParamIndex::kNameTest)>(info.param) + "_" +
           Derived::PrintTestParam(test_param);
  }

 protected:
  /// @brief Runs one functional test described by @p test_param.
  ///
  /// The test is skipped when its name contains "disabled", or when it names an "_all"/"_mpi" task while the
  /// process is not running under mpirun. GTEST_SKIP() returns from this function, so nothing after it runs.
  /// @param test_param Parameter tuple of the current test; taken by const reference to avoid copying the
  ///        task factory, name and test parameter on every test.
  void ExecuteTest(const FuncTestParam<InType, OutType, TestType>& test_param) {
    const std::string& test_name = std::get<static_cast<std::size_t>(GTestParamIndex::kNameTest)>(test_param);

    ValidateTestName(test_name);

    if (IsTestDisabled(test_name)) {
      GTEST_SKIP();
    }

    if (ShouldSkipNonMpiTask(test_name)) {
      std::cerr << "kALL and kMPI tasks are not under mpirun\n";
      GTEST_SKIP();
    }

    InitializeAndRunTask(test_param);
  }

  /// @brief Records a non-fatal failure if the test name contains "unknown".
  /// @param test_name Name element of the test parameter tuple.
  static void ValidateTestName(const std::string& test_name) {
    EXPECT_FALSE(test_name.find("unknown") != std::string::npos);
  }

  /// @brief Tells whether the test is marked as disabled through its name.
  /// @param test_name Name element of the test parameter tuple.
  /// @return true if the name contains "disabled".
  static bool IsTestDisabled(const std::string& test_name) {
    return test_name.find("disabled") != std::string::npos;
  }

  /// @brief Tells whether the test must be skipped because it needs mpirun but is not running under it.
  /// @param test_name Name element of the test parameter tuple.
  /// @return true if ppc::util::IsUnderMpirun() is false and the name contains "_all" or "_mpi".
  static bool ShouldSkipNonMpiTask(const std::string& test_name) {
    auto contains_substring = [&](const std::string& substring) {
      return test_name.find(substring) != std::string::npos;
    };

    return !ppc::util::IsUnderMpirun() && (contains_substring("_all") || contains_substring("_mpi"));
  }

  /// @brief Initializes task instance and runs it through the full pipeline.
  /// @param test_param Parameter tuple whose task-getter element builds the task from GetTestInputData().
  void InitializeAndRunTask(const FuncTestParam<InType, OutType, TestType>& test_param) {
    task_ = std::get<static_cast<std::size_t>(GTestParamIndex::kTaskGetter)>(test_param)(GetTestInputData());
    // Fail this test cleanly instead of dereferencing a null task in ExecuteTaskPipeline().
    ASSERT_TRUE(task_ != nullptr) << "Task getter returned a null task";
    ExecuteTaskPipeline();
  }

  /// @brief Executes the full task pipeline with validation.
  ///
  /// Every stage uses a non-fatal EXPECT_TRUE, so all stages are always invoked in order even if an earlier
  /// one reports a failure.
  // NOLINTNEXTLINE(readability-function-cognitive-complexity)
  void ExecuteTaskPipeline() {
    EXPECT_TRUE(task_->Validation());
    EXPECT_TRUE(task_->PreProcessing());
    EXPECT_TRUE(task_->Run());
    EXPECT_TRUE(task_->PostProcessing());
    EXPECT_TRUE(CheckTestOutputData(task_->GetOutput()));
  }

 private:
  /// Task under test; assigned by InitializeAndRunTask().
  ppc::task::TaskPtr<InType, OutType> task_;
};
/// @brief Passes the elements of @p t selected by Is to ::testing::Values.
/// @tparam Tuple Type whose elements can be read with std::get<I>.
/// @tparam Is Indices of the elements to pass on, in the order they are passed.
/// @param t Source of the values.
/// @return Result of ::testing::Values(std::get<Is>(t)...).
template <typename Tuple, std::size_t... Is>
auto ExpandToValuesImpl(const Tuple& t, std::index_sequence<Is...> /*unused*/) {
// The index_sequence argument only carries the pack Is; its value is never read.
return ::testing::Values(std::get<Is>(t)...);
}
/// @brief Turns every element of a tuple-like object into a ::testing::Values(...) argument.
/// @tparam Tuple Type for which std::tuple_size_v is defined.
/// @param t Source of the values; all of its elements are expanded, in index order.
/// @return Whatever ExpandToValuesImpl returns for the index range [0, std::tuple_size_v<Tuple>).
template <typename Tuple>
auto ExpandToValues(const Tuple& t) {
  // One index per element of t.
  using AllIndices = std::make_index_sequence<std::tuple_size_v<Tuple>>;
  return ExpandToValuesImpl(t, AllIndices{});
}
/// @brief Builds one test-parameter tuple per selected element of @p sizes for the task type Task.
///
/// Each inner tuple holds, in this order: ppc::task::TaskGetter<Task, InType>, the task name, and sizes[Is].
/// The task name is std::string(GetNamespace<Task>()) + "_" +
/// ppc::task::GetStringTaskType(Task::GetStaticTypeOfTask(), settings_path).
/// @tparam Task Task type under test.
/// @tparam InType Input type of the task.
/// @tparam SizesContainer Container of test parameters, indexable with operator[].
/// @tparam Is Indices of the elements of @p sizes to expand.
/// @param sizes Test parameters; one inner tuple is produced per index in Is.
/// @param settings_path Forwarded to ppc::task::GetStringTaskType.
/// @return std::tuple of the inner tuples, in index order (an empty std::tuple<> when Is is empty).
template <typename Task, typename InType, typename SizesContainer, std::size_t... Is>
auto GenTaskTuplesImpl(const SizesContainer& sizes, const std::string& settings_path,
                       std::index_sequence<Is...> /*unused*/) {
  if constexpr (sizeof...(Is) == 0) {
    // Nothing to expand: as before, the task name is never computed in this case.
    return std::make_tuple();
  } else {
    // The name does not depend on Is, so it is built once here instead of once per expanded element.
    const std::string task_name = std::string(GetNamespace<Task>()) + "_" +
                                  ppc::task::GetStringTaskType(Task::GetStaticTypeOfTask(), settings_path);
    return std::make_tuple(std::make_tuple(ppc::task::TaskGetter<Task, InType>, task_name, sizes[Is])...);
  }
}
/// @brief Generates the test-parameter tuples of Task for every element of @p sizes.
/// @tparam Task Task type under test.
/// @tparam InType Input type of the task.
/// @tparam SizesContainer Container type for which std::tuple_size_v is defined.
/// @param sizes Test parameters; each element yields one parameter tuple.
/// @param settings_path Passed through to GenTaskTuplesImpl.
/// @return Result of GenTaskTuplesImpl for the index range [0, number of elements in @p sizes).
template <typename Task, typename InType, typename SizesContainer>
auto TaskListGenerator(const SizesContainer& sizes, const std::string& settings_path) {
  using Container = std::decay_t<SizesContainer>;
  // The element count comes from the container type, not from a runtime size.
  constexpr std::size_t kCount = std::tuple_size_v<Container>;
  return GenTaskTuplesImpl<Task, InType>(sizes, settings_path, std::make_index_sequence<kCount>{});
}
/// @brief Forwards both arguments unchanged to TaskListGenerator<Task, InType> and returns its result.
/// NOTE(review): declared constexpr although TaskListGenerator is not — presumably this can never be evaluated
/// in a constant expression; confirm whether constexpr is intended.
/// @tparam Task Task type under test.
/// @tparam InType Input type of the task.
/// @tparam SizesContainer Type of the container holding the test parameters.
/// @param sizes Test parameters.
/// @param settings_path Settings path handed on to TaskListGenerator.
/// @return Whatever TaskListGenerator returns for the same arguments.
template <typename Task, typename InType, typename SizesContainer>
constexpr auto AddFuncTask(const SizesContainer& sizes, const std::string& settings_path) {
return TaskListGenerator<Task, InType>(sizes, settings_path);
}
} // namespace ppc::util