|
11 | 11 | #include "core/task/include/task.hpp" |
12 | 12 | #include "mpi/example/include/ops_mpi.hpp" |
13 | 13 |
|
14 | | -TEST(nesterov_a_test_task_mpi, test_pipeline_run) { |
15 | | - constexpr int kCount = 500; |
16 | | - |
17 | | - // Create data |
18 | | - std::vector<int> in(kCount * kCount, 0); |
19 | | - |
20 | | - for (size_t i = 0; i < kCount; i++) { |
21 | | - in[(i * kCount) + i] = 1; |
22 | | - } |
23 | | - |
24 | | - // Create Task |
25 | | - auto test_task_mpi = std::make_shared<nesterov_a_test_task_mpi::TestTaskMPI>(in); |
26 | | - |
27 | | - // Create Perf attributes |
28 | | - ppc::core::PerfAttr perf_attr; |
29 | | - const auto t0 = std::chrono::high_resolution_clock::now(); |
30 | | - perf_attr.current_timer = [&] { |
31 | | - auto current_time_point = std::chrono::high_resolution_clock::now(); |
32 | | - auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(current_time_point - t0).count(); |
33 | | - return static_cast<double>(duration) * 1e-9; |
34 | | - }; |
35 | | - |
36 | | - // Create and init perf results |
37 | | - ppc::core::PerfResults perf_results; |
38 | | - |
39 | | - ppc::core::Perf perf_analyzer(test_task_mpi); |
40 | | - perf_analyzer.PipelineRun(perf_attr, perf_results); |
41 | | - // Create Perf analyzer |
42 | | - int rank = -1; |
43 | | - MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
44 | | - if (rank == 0) { |
45 | | - ppc::core::Perf::PrintPerfStatistic(perf_results); |
46 | | - } |
47 | | - |
48 | | - ASSERT_EQ(in, test_task_mpi->Get()); |
49 | | -} |
50 | | - |
51 | | -TEST(nesterov_a_test_task_mpi, test_task_run) { |
52 | | - constexpr int kCount = 500; |
53 | | - |
54 | | - // Create data |
55 | | - std::vector<int> in(kCount * kCount, 0); |
56 | | - |
57 | | - for (size_t i = 0; i < kCount; i++) { |
58 | | - in[(i * kCount) + i] = 1; |
| 14 | +class NesterovATaskMPITest : public ::testing::TestWithParam<ppc::core::PerfResults::TypeOfRunning> { |
| 15 | + protected: |
| 16 | + static void RunTest(ppc::core::PerfResults::TypeOfRunning mode) { |
| 17 | + constexpr int kCount = 500; |
| 18 | + |
| 19 | + // Create data |
| 20 | + std::vector<int> in(kCount * kCount, 0); |
| 21 | + for (size_t i = 0; i < kCount; i++) { |
| 22 | + in[(i * kCount) + i] = 1; |
| 23 | + } |
| 24 | + |
| 25 | + // Create Task |
| 26 | + auto test_task_mpi = std::make_shared<nesterov_a_test_task_mpi::TestTaskMPI>(in); |
| 27 | + |
| 28 | + // Create Perf attributes |
| 29 | + ppc::core::PerfAttr perf_attr; |
| 30 | + const auto t0 = std::chrono::high_resolution_clock::now(); |
| 31 | + perf_attr.current_timer = [&] { |
| 32 | + auto current_time_point = std::chrono::high_resolution_clock::now(); |
| 33 | + auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(current_time_point - t0).count(); |
| 34 | + return static_cast<double>(duration) * 1e-9; |
| 35 | + }; |
| 36 | + |
| 37 | + // Create and init perf results |
| 38 | + ppc::core::PerfResults perf_results; |
| 39 | + |
| 40 | + // Create Perf analyzer |
| 41 | + ppc::core::Perf perf_analyzer(test_task_mpi); |
| 42 | + |
| 43 | + if (mode == ppc::core::PerfResults::TypeOfRunning::kPipeline) { |
| 44 | + perf_analyzer.PipelineRun(perf_attr, perf_results); |
| 45 | + } else { |
| 46 | + perf_analyzer.TaskRun(perf_attr, perf_results); |
| 47 | + } |
| 48 | + |
| 49 | + int rank = -1; |
| 50 | + MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 51 | + if (rank == 0) { |
| 52 | + ppc::core::Perf::PrintPerfStatistic(perf_results); |
| 53 | + } |
| 54 | + |
| 55 | + ASSERT_EQ(in, test_task_mpi->Get()); |
59 | 56 | } |
| 57 | +}; |
60 | 58 |
|
61 | | - // Create Task |
62 | | - auto test_task_mpi = std::make_shared<nesterov_a_test_task_mpi::TestTaskMPI>(in); |
63 | | - |
64 | | - // Create Perf attributes |
65 | | - ppc::core::PerfAttr perf_attr; |
66 | | - const auto t0 = std::chrono::high_resolution_clock::now(); |
67 | | - perf_attr.current_timer = [&] { |
68 | | - auto current_time_point = std::chrono::high_resolution_clock::now(); |
69 | | - auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(current_time_point - t0).count(); |
70 | | - return static_cast<double>(duration) * 1e-9; |
71 | | - }; |
72 | | - |
73 | | - // Create and init perf results |
74 | | - ppc::core::PerfResults perf_results; |
75 | | - |
76 | | - // Create Perf analyzer |
77 | | - ppc::core::Perf perf_analyzer(test_task_mpi); |
78 | | - perf_analyzer.TaskRun(perf_attr, perf_results); |
79 | | - // Create Perf analyzer |
80 | | - int rank = -1; |
81 | | - MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
82 | | - if (rank == 0) { |
83 | | - ppc::core::Perf::PrintPerfStatistic(perf_results); |
84 | | - } |
| 59 | +TEST_P(NesterovATaskMPITest, RunModes) { RunTest(GetParam()); } |
85 | 60 |
|
86 | | - ASSERT_EQ(in, test_task_mpi->Get()); |
87 | | -} |
| 61 | +INSTANTIATE_TEST_SUITE_P(NesterovATests, NesterovATaskMPITest, |
| 62 | + ::testing::Values(ppc::core::PerfResults::TypeOfRunning::kPipeline, |
| 63 | + ppc::core::PerfResults::TypeOfRunning::kTaskRun)); |
0 commit comments