Skip to content

Commit ff384c9

Browse files
authored
[SYCLomatic] Enable the migration of 2 APIs cub::BlockExchange.WarpStripedToBlocked/BlockedToWarpStriped with helper functions (#2679)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent a653161 commit ff384c9

6 files changed

Lines changed: 216 additions & 7 deletions

File tree

clang/lib/DPCT/RulesLang/RewriterSYCLcompat.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.BlockedToStriped")
9494
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.StripedToBlocked")
9595
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToBlocked")
9696
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToStriped")
97+
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.WarpStripedToBlocked")
98+
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.BlockedToWarpStriped")
9799
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Offset")
98100
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Rotate")
99101
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Up")

clang/lib/DPCT/RulesLangLib/CUB/RewriterClassMethods.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,28 @@ RewriterMap dpct::createClassMethodsRewriterMap() {
211211
"cub::BlockExchange.ScatterToStriped",
212212
MemberExprBase(), false, "scatter_to_striped",
213213
NDITEM, ARG(0), ARG(1)))
214+
// cub::BlockExchange.BlockedToWarpStriped
215+
SUBGROUPSIZE_FACTORY(
216+
UINT_MAX,
217+
MapNames::getDpctNamespace() +
218+
"exchange.blocked_to_sub_group_striped",
219+
HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils,
220+
MEMBER_CALL_FACTORY_ENTRY(
221+
"cub::BlockExchange.BlockedToWarpStriped",
222+
MemberExprBase(), false,
223+
"blocked_to_sub_group_striped", NDITEM,
224+
ARG(0), ARG(1))))
225+
// cub::BlockExchange.WarpStripedToBlocked
226+
SUBGROUPSIZE_FACTORY(
227+
UINT_MAX,
228+
MapNames::getDpctNamespace() +
229+
"exchange.sub_group_striped_to_blocked",
230+
HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils,
231+
MEMBER_CALL_FACTORY_ENTRY(
232+
"cub::BlockExchange.WarpStripedToBlocked",
233+
MemberExprBase(), false,
234+
"sub_group_striped_to_blocked", NDITEM,
235+
ARG(0), ARG(1))))
214236
// cub::BlockShuffle.Offset
215237
HEADER_INSERT_FACTORY(
216238
HeaderType::HT_DPCT_GROUP_Utils,

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ void CubMemberCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
167167
"normalize", "Sort", "SortDescending", "BlockedToStriped",
168168
"StripedToBlocked", "ScatterToBlocked", "ScatterToStriped",
169169
"SortBlockedToStriped", "SortDescendingBlockedToStriped",
170-
"Load", "Store", "Offset", "Rotate", "Up", "Down")))))
170+
"Load", "Store", "Offset", "Rotate", "Up", "Down",
171+
"BlockedToWarpStriped", "WarpStripedToBlocked")))))
171172
.bind("memberCall"),
172173
this);
173174

@@ -253,7 +254,8 @@ void CubMemberCallRule::runRule(
253254
bool isBlockExchange =
254255
Name == "BlockedToStriped" || Name == "StripedToBlocked" ||
255256
Name == "StripedToBlocked" || Name == "ScatterToBlocked" ||
256-
Name == "ScatterToStriped";
257+
Name == "ScatterToStriped" || Name == "WarpStripedToBlocked" ||
258+
Name == "BlockedToWarpStriped";
257259
bool isBlockShuffle =
258260
Name == "Offset" || Name == "Rotate" || Name == "Up" || Name == "Down";
259261
if (isBlockRadixSort || isBlockExchange || isBlockShuffle ||

clang/lib/DPCT/SrcAPI/APINames_CUB.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagTail
9898
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagHeadsAndTails, FlagHeadsAndTails, false, NO_FLAG, P4, "Comment")
9999
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, StripedToBlocked, StripedToBlocked, true, NO_FLAG, P4, "Successful")
100100
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, BlockedToStriped, BlockedToStriped, true, NO_FLAG, P4, "Successful")
101-
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, WarpStripedToBlocked, WarpStripedToBlocked, false, NO_FLAG, P4, "Comment")
102-
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, BlockedToWarpStriped, BlockedToWarpStriped, false, NO_FLAG, P4, "Comment")
101+
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, WarpStripedToBlocked, WarpStripedToBlocked, true, NO_FLAG, P4, "Comment")
102+
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, BlockedToWarpStriped, BlockedToWarpStriped, true, NO_FLAG, P4, "Comment")
103103
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToBlocked, ScatterToBlocked, true, NO_FLAG, P4, "Successful")
104104
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToStriped, ScatterToStriped, true, NO_FLAG, P4, "Successful")
105105
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToStripedGuarded, ScatterToStripedGuarded, false, NO_FLAG, P4, "Comment")

clang/runtime/dpct-rt/include/dpct/group_utils.hpp

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,25 @@ template <typename T, size_t ElementsPerWorkItem> class exchange {
5656

5757
struct striped_offset {
5858
template <typename Item> size_t operator()(Item item, size_t i) {
59-
size_t offset = i * item.get_local_range(2) * item.get_local_range(1) *
60-
item.get_local_range(0) +
59+
size_t offset = i * item.get_group().get_local_linear_range() +
6160
item.get_local_linear_id();
6261
return adjust_by_padding(offset);
6362
}
6463
};
6564

65+
struct sub_group_striped_offset {
66+
template <typename Item> size_t operator()(Item item, size_t i) {
67+
auto sg = item.get_sub_group();
68+
size_t wg_size = item.get_group().get_local_linear_range();
69+
size_t sub_group_sliced_items =
70+
std::min<size_t>(sg.get_local_linear_range(), wg_size);
71+
size_t offset = (sg.get_group_linear_id() * sub_group_sliced_items *
72+
ElementsPerWorkItem) +
73+
(i * sub_group_sliced_items) + sg.get_local_linear_id();
74+
return adjust_by_padding(offset);
75+
}
76+
};
77+
6678
template <typename Iterator> struct scatter_offset {
6779
Iterator begin;
6880
scatter_offset(const int (&ranks)[ElementsPerWorkItem]) {
@@ -241,6 +253,60 @@ template <typename T, size_t ElementsPerWorkItem> class exchange {
241253
helper_exchange(item, input, input, get_scatter_offset, get_striped_offset);
242254
}
243255

256+
/// Rearrange elements from blocked order to sub_group striped order.
257+
///
258+
/// Suppose 512 integer data elements partitioned across 128 work-items, where
259+
/// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the
260+
/// blocked \p input across the work-group is:
261+
///
262+
/// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }.
263+
///
264+
/// The sub_group striped order output (with sub_group size 2) is:
265+
///
266+
/// { [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], ...
267+
/// , [505, 507, 509, 511] }.
268+
///
269+
/// \tparam Item The work-item identifier type.
270+
/// \param item The work-item identifier.
271+
/// \param input The input data of each work-item.
272+
/// \param output The corresponding output data of each work-item.
273+
template <typename Item>
274+
__dpct_inline__ void
275+
blocked_to_sub_group_striped(Item item, T (&input)[ElementsPerWorkItem],
276+
T (&output)[ElementsPerWorkItem]) {
277+
blocked_offset get_blocked_offset;
278+
sub_group_striped_offset get_sub_group_striped_offset;
279+
helper_exchange(item, input, output, get_blocked_offset,
280+
get_sub_group_striped_offset);
281+
}
282+
283+
/// Rearrange elements from sub_group striped order to blocked order.
284+
///
285+
/// Suppose 512 integer data elements partitioned across 128 work-items, where
286+
/// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the
287+
/// sub_group striped \p input across the work-group is:
288+
///
289+
/// { [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], ...
290+
/// , [505, 507, 509, 511] }.
291+
///
292+
/// The blocked order output is:
293+
///
294+
/// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }.
295+
///
296+
/// \tparam Item The work-item identifier type.
297+
/// \param item The work-item identifier.
298+
/// \param input The input data of each work-item.
299+
/// \param output The corresponding output data of each work-item.
300+
template <typename Item>
301+
__dpct_inline__ void
302+
sub_group_striped_to_blocked(Item item, T (&input)[ElementsPerWorkItem],
303+
T (&output)[ElementsPerWorkItem]) {
304+
blocked_offset get_blocked_offset;
305+
sub_group_striped_offset get_sub_group_striped_offset;
306+
helper_exchange(item, input, output, get_sub_group_striped_offset,
307+
get_blocked_offset);
308+
}
309+
244310
private:
245311
template <typename Item, typename offsetFunctorTypeFW,
246312
typename offsetFunctorTypeRV>

clang/test/dpct/cub/blocklevel/blockexchange.cu

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,34 @@ __global__ void ScatterToStripedKernel(int *d_data, int *d_rank) {
8484
cub::StoreDirectStriped<128>(threadIdx.x, d_data, thread_data);
8585
}
8686

87+
__global__ void BlockedToWarpStripedKernel(int *d_data) {
88+
// CHECK: typedef dpct::group::exchange<int, 4> BlockExchange;
89+
// CHECK: int thread_data[4];
90+
// CHECK: dpct::group::load_direct_blocked(item_ct1, d_data, thread_data);
91+
// CHECK: BlockExchange(temp_storage).blocked_to_sub_group_striped(item_ct1, thread_data, thread_data);
92+
// CHECK: dpct::group::store_direct_blocked(item_ct1, d_data, thread_data);
93+
typedef cub::BlockExchange<int, 128, 4> BlockExchange;
94+
__shared__ typename BlockExchange::TempStorage temp_storage;
95+
int thread_data[4];
96+
cub::LoadDirectBlocked(threadIdx.x, d_data, thread_data);
97+
BlockExchange(temp_storage).BlockedToWarpStriped(thread_data, thread_data);
98+
cub::StoreDirectBlocked(threadIdx.x, d_data, thread_data);
99+
}
100+
101+
__global__ void WarpStripedToBlockedKernel(int *d_data) {
102+
// CHECK: typedef dpct::group::exchange<int, 4> BlockExchange;
103+
// CHECK: int thread_data[4];
104+
// CHECK: dpct::group::load_direct_blocked(item_ct1, d_data, thread_data);
105+
// CHECK: BlockExchange(temp_storage).sub_group_striped_to_blocked(item_ct1, thread_data, thread_data);
106+
// CHECK: dpct::group::store_direct_blocked(item_ct1, d_data, thread_data);
107+
typedef cub::BlockExchange<int, 128, 4> BlockExchange;
108+
__shared__ typename BlockExchange::TempStorage temp_storage;
109+
int thread_data[4];
110+
cub::LoadDirectBlocked(threadIdx.x, d_data, thread_data);
111+
BlockExchange(temp_storage).WarpStripedToBlocked(thread_data, thread_data);
112+
cub::StoreDirectBlocked(threadIdx.x, d_data, thread_data);
113+
}
114+
87115
bool test_striped_to_blocked() {
88116
int *d_data;
89117
cudaMallocManaged(&d_data, sizeof(int) * 512);
@@ -257,7 +285,96 @@ bool test_scatter_to_striped() {
257285
return true;
258286
}
259287

288+
bool test_blocked_to_warp_striped() {
289+
int *d_data, expected[512];
290+
cudaMallocManaged(&d_data, sizeof(int) * 512);
291+
for (int i = 0; i < 512; ++i)
292+
d_data[i] = i;
293+
294+
295+
// CHECK: q_ct1.submit(
296+
// CHECK-NEXT: [&](sycl::handler &cgh) {
297+
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_acc(dpct::group::exchange<int, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
298+
// CHECK-EMPTY:
299+
// CHECK-NEXT: cgh.parallel_for(
300+
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
301+
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} {
302+
// CHECK-NEXT: BlockedToWarpStripedKernel(d_data, item_ct1, &temp_storage_acc[0]);
303+
// CHECK-NEXT: });
304+
// CHECK-NEXT: });
305+
BlockedToWarpStripedKernel<<<1, 128>>>(d_data);
306+
cudaDeviceSynchronize();
307+
size_t warp_id = 0, warp_offset = 0, lane_id = 0;
308+
for (int i = 0; i < 128; i++) {
309+
warp_id = i / 32;
310+
lane_id = i % 32;
311+
warp_offset = warp_id * 32 * 4;
312+
expected[4 * i + 0] = warp_offset + lane_id + 0 * 32;
313+
expected[4 * i + 1] = warp_offset + lane_id + 1 * 32;
314+
expected[4 * i + 2] = warp_offset + lane_id + 2 * 32;
315+
expected[4 * i + 3] = warp_offset + lane_id + 3 * 32;
316+
}
317+
318+
for (int i = 0; i < 512; ++i) {
319+
if (expected[i] != d_data[i]) {
320+
std::cout << "test_blocked_to_warp_striped failed\n";
321+
std::ostream_iterator<int> Iter(std::cout, ", ");
322+
std::copy(d_data, d_data + 512, Iter);
323+
std::cout << std::endl;
324+
std::copy(expected, expected + 512, Iter);
325+
std::cout << std::endl;
326+
return false;
327+
}
328+
}
329+
std::cout << "test_blocked_to_warp_striped pass\n";
330+
return true;
331+
}
332+
333+
bool test_warp_striped_to_blocked() {
334+
int *d_data, expected[512];
335+
cudaMallocManaged(&d_data, sizeof(int) * 512);
336+
size_t warp_id = 0, warp_offset = 0, lane_id = 0;
337+
for (int i = 0; i < 128; i++) {
338+
warp_id = i / 32;
339+
lane_id = i % 32;
340+
warp_offset = warp_id * 32 * 4;
341+
d_data[4 * i + 0] = warp_offset + lane_id + 0 * 32;
342+
d_data[4 * i + 1] = warp_offset + lane_id + 1 * 32;
343+
d_data[4 * i + 2] = warp_offset + lane_id + 2 * 32;
344+
d_data[4 * i + 3] = warp_offset + lane_id + 3 * 32;
345+
}
346+
// CHECK: q_ct1.submit(
347+
// CHECK-NEXT: [&](sycl::handler &cgh) {
348+
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_acc(dpct::group::exchange<int, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
349+
// CHECK-EMPTY:
350+
// CHECK-NEXT: cgh.parallel_for(
351+
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
352+
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} {
353+
// CHECK-NEXT: WarpStripedToBlockedKernel(d_data, item_ct1, &temp_storage_acc[0]);
354+
// CHECK-NEXT: });
355+
// CHECK-NEXT: });
356+
WarpStripedToBlockedKernel<<<1, 128>>>(d_data);
357+
cudaDeviceSynchronize();
358+
359+
for (int i = 0; i < 512; i++) {
360+
expected[i] = i;
361+
}
362+
363+
for (int i = 0; i < 512; ++i) {
364+
if (expected[i] != d_data[i]) {
365+
std::cout << "test_warp_striped_to_blocked failed\n";
366+
std::ostream_iterator<int> Iter(std::cout, ", ");
367+
std::copy(d_data, d_data + 512, Iter);
368+
std::cout << std::endl;
369+
return false;
370+
}
371+
}
372+
std::cout << "test_warp_striped_to_blocked pass\n";
373+
return true;
374+
}
375+
260376
int main() {
261377
return !(test_blocked_to_striped() && test_striped_to_blocked() &&
262-
test_scatter_to_blocked() && test_scatter_to_striped());
378+
test_scatter_to_blocked() && test_scatter_to_striped() &&
379+
test_blocked_to_warp_striped() && test_warp_striped_to_blocked());
263380
}

0 commit comments

Comments
 (0)