Skip to content

Commit 9b620bf

Browse files
author
the-slow-one
authored
[SYCLomatic]Migrate types used while casting stream and events (#2686)
1 parent 903eeab commit 9b620bf

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4367,11 +4367,12 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
43674367
StreamName = "{{NEEDREPLACEQ" + std::to_string(Index) + "}}.";
43684368
ReplStr = StreamName + "ext_oneapi_empty()";
43694369
} else {
4370-
StreamName = getStmtSpelling(StreamArg);
4370+
ExprAnalysis EA(StreamArg);
4371+
ReplStr = EA.getReplacedString();
43714372
if (needExtraParensInMemberExpr(StreamArg)) {
4372-
StreamName = "(" + StreamName + ")";
4373+
ReplStr = "(" + ReplStr + ")";
43734374
}
4374-
ReplStr = StreamName + "->" + "ext_oneapi_empty()";
4375+
ReplStr = ReplStr + "->" + "ext_oneapi_empty()";
43754376
}
43764377
if (IsAssigned) {
43774378
ReplStr = MapNames::getCheckErrorMacroName() + "((" + ReplStr + "))";
@@ -4414,7 +4415,12 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
44144415

44154416
StmtStr0 = "{{NEEDREPLACEQ" + std::to_string(Index) + "}}.";
44164417
} else {
4417-
StmtStr0 = getStmtSpelling(CE->getArg(0)) + "->";
4418+
ExprAnalysis StreamArgEA(StreamArg);
4419+
StmtStr0 = StreamArgEA.getReplacedString();
4420+
if (needExtraParensInMemberExpr(StreamArg)) {
4421+
StmtStr0 = "(" + StmtStr0 + ")";
4422+
}
4423+
StmtStr0 += "->";
44184424
}
44194425
ReplStr = StmtStr0 + "ext_oneapi_submit_barrier({" +
44204426
StmtStr1 + "})";

clang/test/dpct/driver-stream-and-event.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ void foo(){
2929
CUresult streamStatus = cuStreamQuery(s);
3030
if (streamStatus == CUDA_SUCCESS);
3131

32+
unsigned long long stream_addr;
33+
// CHECK: ((dpct::queue_ptr)stream_addr)->ext_oneapi_empty();
34+
cuStreamQuery((CUstream)stream_addr);
35+
3236
//CHECK: s->wait();
3337
cuStreamSynchronize(s);
3438

@@ -39,6 +43,10 @@ void foo(){
3943
cuEventCreate(&e, CU_EVENT_DEFAULT);
4044
cuStreamWaitEvent(s, e, 0);
4145

46+
unsigned long long event_addr;
47+
// CHECK: ((dpct::queue_ptr)stream_addr)->ext_oneapi_submit_barrier({*(dpct::event_ptr)event_addr});
48+
cuStreamWaitEvent((CUstream)stream_addr, (CUevent)event_addr, 0);
49+
4250
//CHECK: /*
4351
//CHECK-NEXT: DPCT1012:{{[0-9]+}}: Detected kernel execution time measurement pattern and generated an initial code for time measurements in SYCL. You can change the way time is measured depending on your goals.
4452
//CHECK-NEXT: */

0 commit comments

Comments
 (0)