|
10 | 10 |
|
11 | 11 | #include <sycl/ext/oneapi/experimental/graph.hpp> |
12 | 12 | #include <sycl/sycl.hpp> |
| 13 | +#include <unordered_map> |
13 | 14 |
|
14 | 15 | namespace dpct { |
15 | 16 | namespace experimental { |
@@ -65,8 +66,42 @@ class graph_mgr { |
65 | 66 | (*graph)->end_recording(); |
66 | 67 | } |
67 | 68 |
|
| 69 | + void get_nodes(dpct::experimental::command_graph_ptr graph, |
| 70 | + dpct::experimental::node_ptr *nodesArray, |
| 71 | + std::size_t *numberOfNodes) { |
| 72 | + auto nodes = graph->get_nodes(); |
| 73 | + nodes_map[graph] = nodes; |
| 74 | + *numberOfNodes = nodes.size(); |
| 75 | + if (!nodesArray) { |
| 76 | + return; |
| 77 | + } |
| 78 | + for (std::size_t i = 0; i < *numberOfNodes; i++) { |
| 79 | + nodesArray[i] = &nodes_map[graph][i]; |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + void get_root_nodes(dpct::experimental::command_graph_ptr graph, |
| 84 | + dpct::experimental::node_ptr *nodesArray, |
| 85 | + std::size_t *numberOfNodes) { |
| 86 | + auto root_nodes = graph->get_root_nodes(); |
| 87 | + root_nodes_map[graph] = root_nodes; |
| 88 | + *numberOfNodes = root_nodes.size(); |
| 89 | + if (!nodesArray) { |
| 90 | + return; |
| 91 | + } |
| 92 | + for (std::size_t i = 0; i < *numberOfNodes; i++) { |
| 93 | + nodesArray[i] = &root_nodes_map[graph][i]; |
| 94 | + } |
| 95 | + } |
| 96 | + |
68 | 97 | private: |
69 | 98 | std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map; |
| 99 | + std::unordered_map<dpct::experimental::command_graph_ptr, |
| 100 | + std::vector<sycl::ext::oneapi::experimental::node>> |
| 101 | + nodes_map; |
| 102 | + std::unordered_map<dpct::experimental::command_graph_ptr, |
| 103 | + std::vector<sycl::ext::oneapi::experimental::node>> |
| 104 | + root_nodes_map; |
70 | 105 | }; |
71 | 106 | } // namespace detail |
72 | 107 |
|
@@ -133,5 +168,28 @@ static void add_dependencies(dpct::experimental::command_graph_ptr graph, |
133 | 168 | } |
134 | 169 | } |
135 | 170 |
|
| 171 | +/// Gets the nodes in the command graph. |
| 172 | +/// \param [in] graph A pointer to the command graph. |
| 173 | +/// \param [out] nodesArray An array of node pointers where the |
| 174 | +/// nodes will be assigned. |
| 175 | +/// \param [out] numberOfNodes The number of nodes in the graph. |
| 176 | +static void get_nodes(dpct::experimental::command_graph_ptr graph, |
| 177 | + dpct::experimental::node_ptr *nodesArray, |
| 178 | + std::size_t *numberOfNodes) { |
| 179 | + detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes); |
| 180 | +} |
| 181 | + |
| 182 | +/// Gets the root nodes in the command graph. |
| 183 | +/// \param [in] graph A pointer to the command graph. |
| 184 | +/// \param [out] nodesArray An array of node pointers where the |
| 185 | +/// root nodes will be assigned. |
| 186 | +/// \param [out] numberOfNodes The number of root nodes in the graph. |
| 187 | +static void get_root_nodes(dpct::experimental::command_graph_ptr graph, |
| 188 | + dpct::experimental::node_ptr *nodesArray, |
| 189 | + std::size_t *numberOfNodes) { |
| 190 | + detail::graph_mgr::instance().get_root_nodes(graph, nodesArray, |
| 191 | + numberOfNodes); |
| 192 | +} |
| 193 | + |
136 | 194 | } // namespace experimental |
137 | 195 | } // namespace dpct |
0 commit comments