Skip to content

Commit f1c2029

Browse files
authored
[SYCLomatic] Support migration of PTX instruction cp.async to sync operation(#2672)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent 69193da commit f1c2029

9 files changed

Lines changed: 143 additions & 18 deletions

File tree

clang/lib/DPCT/Diagnostics/Diagnostics.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ DEF_WARNING(JOINT_MATRIX_SHAPE, 1135, HIGH_LEVEL, "Please check if joint_matrix
298298
DEF_COMMENT(JOINT_MATRIX_SHAPE, 1135, HIGH_LEVEL, "Please check if joint_matrix implementations support the combination of data type and matrix shape type in the target hardware.")
299299
DEF_WARNING(UNSUPPORTED_EXTMEM_WIN_HANDLE, 1136, HIGH_LEVEL, "SYCL Bindless Images extension only supports importing external resource memory using NT handle on Windows. If assert(%0.get_win32_handle()) fails, you may need to adjust the code to use (%0.get_win32_handle()).")
300300
DEF_COMMENT(UNSUPPORTED_EXTMEM_WIN_HANDLE, 1136, HIGH_LEVEL, "SYCL Bindless Images extension only supports importing external resource memory using NT handle on Windows. If assert({0}.get_win32_handle()) fails, you may need to adjust the code to use ({0}.get_win32_handle()).")
301+
DEF_WARNING(ASYNC_COPY_DEVICE_WARN, 1137, LOW_LEVEL, "ASM instruction \"cp.async\" is asynchronous copy, current it is migrated to synchronous copy operation. You may need to adjust the code to tune the performance.")
302+
DEF_COMMENT(ASYNC_COPY_DEVICE_WARN, 1137, LOW_LEVEL, "ASM instruction \"cp.async\" is asynchronous copy, current it is migrated to synchronous copy operation. You may need to adjust the code to tune the performance.")
303+
301304
// clang-format on
302305

303306
#undef DEF_COMMENT

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using namespace clang::dpct;
3737
namespace {
3838

3939
/// Value returned by the SYCL generator routines to signal failure (true).
inline bool SYCLGenError() { return true; }
40-
inline bool SYCLGenSuccess() { return false; }
40+
/// Value returned by the SYCL generator routines to signal success (false);
/// the counterpart of SYCLGenError().
inline bool SYCLGenSuccess() { return false; }
4141

4242
/// This is used to handle all the AST nodes (except specific instructions, e.g.
4343
/// mov/setp), and generate functionally equivalent SYCL code.
@@ -589,9 +589,11 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
589589

590590
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591591
// Address expressions are only supported for st/ld/atom/prefetch/red/cp instructions.
592-
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
593-
asmtok::op_prefetch, asmtok::op_red))
592+
if (!CurrInst ||
593+
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594+
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
594595
return SYCLGenError();
596+
}
595597
std::string Type;
596598
if (tryEmitType(Type, CurrInst->getType(0)))
597599
return SYCLGenError();
@@ -618,6 +620,7 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
618620
std::string Reg;
619621
if (tryEmitStmt(Reg, Dst->getSymbol()))
620622
return SYCLGenSuccess();
623+
621624
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
622625
CanSuppressCast(Dst->getSymbol()))
623626
OS() << llvm::formatv("{0}", Reg);
@@ -2769,6 +2772,46 @@ class SYCLGen : public SYCLGenBase {
27692772
endstmt();
27702773
return SYCLGenSuccess();
27712774
}
2775+
2776+
/// Emits a synchronous, element-by-element copy for a `cp` instruction.
///
/// Only instructions with exactly three input operands and exactly one type
/// are handled. The generated code always copies one element from input
/// operand 0 to input operand 2, then copies the elements at offsets 1, 2
/// and 3, each guarded by `operand 1 > 4`, `> 8` and `> 12` respectively.
/// On the success path the ASYNC_COPY_DEVICE_WARN diagnostic is reported.
///
/// NOTE(review): presumably operand 1 is the copy size in bytes and operand 2
/// is the destination that the parser moved to the end of the input list;
/// confirm against the parser's handling of asmtok::op_cp.
///
/// \returns SYCLGenError() if the operand/type counts do not match or an
/// operand cannot be emitted, SYCLGenSuccess() otherwise.
bool handle_cp(const InlineAsmInstruction *Inst) override {
2777+
// Reject every other operand/type layout.
if (Inst->getNumInputOperands() != 3 || Inst->getNumTypes() != 1)
2778+
return SYCLGenError();
2779+
2780+
// Expose this instruction through CurrInst while its operands are emitted
// (emitAddressExpr inspects CurrInst); the guard restores the previous value
// when this function returns.
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
2781+
CurrInst = Inst;
2782+
2783+
// Textual form of each of the three input operands.
std::string Op[3];
2784+
for (int i = 0; i < 3; ++i)
2785+
if (tryEmitStmt(Op[i], Inst->getInputOperand(i)))
2786+
return SYCLGenError();
2787+
2788+
// Calls indent() and then returns the `if (<Op[1]> > Val)` line for the
// caller to stream.
auto CommonIfStat = [&](std::string Val) {
2789+
indent();
2790+
return "if (" + Op[1] + " > " + Val + ")\n";
2791+
};
2792+
2793+
// Calls indent() one level deeper (the level is restored before returning)
// and returns the assignment `*(Op[2] + Val) = *(Op[0] + Val)` without a
// trailing terminator.
auto CommonBody = [&](std::string Val) {
2794+
incIndent();
2795+
indent();
2796+
decIndent();
2797+
return "*(" + Op[2] + " + " + Val + ") = *(" + Op[0] + " + " + Val + ")";
2798+
};
2799+
2800+
// Unconditional copy of the first element.
OS() << "*(" << Op[2] << ") = *(" << Op[0] << ");\n";
2801+
2802+
OS() << CommonIfStat("4");
2803+
OS() << CommonBody("1") << ";\n";
2804+
2805+
OS() << CommonIfStat("8");
2806+
OS() << CommonBody("2") << ";\n";
2807+
2808+
OS() << CommonIfStat("12");
2809+
// The last statement is terminated through endstmt() instead of a literal
// ";\n".
OS() << CommonBody("3");
2810+
endstmt();
2811+
2812+
report(Diagnostics::ASYNC_COPY_DEVICE_WARN, true);
2813+
return SYCLGenSuccess();
2814+
}
27722815
};
27732816

27742817
/// Clean the special character in identifier.
@@ -2985,7 +3028,6 @@ void AsmRule::doMigrateInternel(const GCCAsmStmt *GAS) {
29853028
Parser.addInlineAsmOperands(GAS->getInputExpr(I),
29863029
getReplaceString(GAS->getInputExpr(I)),
29873030
GAS->getInputConstraint(I));
2988-
29893031
do {
29903032
auto Inst = Parser.ParseStatement();
29913033
if (Inst.isInvalid()) {

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ class InlineAsmInstruction : public InlineAsmStmt {
318318
/// e.g. asmtok::op_mov, asmtok::op_setp, etc.
319319
InlineAsmIdentifierInfo *Opcode = nullptr;
320320

321-
std::optional<AsmStateSpace> StateSpace;
321+
SmallVector<AsmStateSpace, 4> StateSpaces;
322322

323323
/// This represents attributes like: comparison operator, rounding modifiers,
324324
/// ... e.g. instruction setp.eq.s32 has a comparison operator 'eq'.
@@ -342,12 +342,14 @@ class InlineAsmInstruction : public InlineAsmStmt {
342342

343343
public:
344344
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345-
std::optional<AsmStateSpace> SS,
345+
SmallVector<AsmStateSpace, 4> AsmStateSpaces,
346346
ArrayRef<InstAttr> Attrs,
347347
ArrayRef<InlineAsmType *> Types, InlineAsmExpr *Out,
348348
InlineAsmExpr *Pred, ArrayRef<InlineAsmExpr *> InOps)
349-
: InlineAsmStmt(InstructionClass), Opcode(Op), StateSpace(SS),
350-
Types(Types), OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
349+
: InlineAsmStmt(InstructionClass), Opcode(Op), Types(Types),
350+
OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
351+
StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(),
352+
AsmStateSpaces.end());
351353
Attributes.insert(Attrs.begin(), Attrs.end());
352354
}
353355

@@ -390,6 +392,10 @@ class InlineAsmInstruction : public InlineAsmStmt {
390392
return InstructionClass <= S->getStmtClass();
391393
}
392394
/// Returns the last entry of StateSpaces, or AsmStateSpace::none when the
/// list is empty.
AsmStateSpace getStateSpace() const {
395+
396+
// NOTE(review): the optional is always engaged here (the conditional already
// yields AsmStateSpace::none for an empty list), so the value_or() fallback
// below is never used.
std::optional<AsmStateSpace> StateSpace =
397+
StateSpaces.size() > 0 ? StateSpaces[StateSpaces.size() - 1]
398+
: AsmStateSpace::none;
393399
return StateSpace.value_or(AsmStateSpace::none);
394400
}
395401
};

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
333333
SmallVector<InstAttr, 4> Attrs;
334334
SmallVector<InlineAsmType *, 4> Types;
335335
SmallVector<InlineAsmExpr *, 4> Ops;
336-
std::optional<AsmStateSpace> StateSpace;
336+
SmallVector<AsmStateSpace, 4> StateSpaces;
337337
while (Tok.startOfDot()) {
338338
switch (Tok.getIdentifier()->getFlags()) {
339339
case InlineAsmIdentifierInfo::BuiltinType:
@@ -343,11 +343,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
343343
Attrs.push_back(ConvertToInstAttr(Tok.getKind()));
344344
break;
345345
case InlineAsmIdentifierInfo::StateSpace:
346-
// Duplicated state space in an single instruction statement.
347-
if (StateSpace.has_value())
348-
return AsmStmtError();
349-
else
350-
StateSpace = ConvertToStateSpace(Tok.getKind());
346+
StateSpaces.push_back(ConvertToStateSpace(Tok.getKind()));
351347
break;
352348
default:
353349
return AsmStmtError();
@@ -383,7 +379,13 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
383379
Types.push_back(Context.getBuiltinType(InlineAsmBuiltinType::byte));
384380
}
385381

386-
return ::new (Context) InlineAsmInstruction(Opcode, StateSpace, Attrs, Types,
382+
if (Opcode->getTokenID() == asmtok::op_cp) {
383+
Ops.push_back(Out.get());
384+
Out = nullptr;
385+
Types.push_back(Context.getBuiltinType(InlineAsmBuiltinType::u32));
386+
}
387+
388+
return ::new (Context) InlineAsmInstruction(Opcode, StateSpaces, Attrs, Types,
387389
Out.get(), Pred.get(), Ops);
388390
}
389391

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ MODIFIER(clamp, ".clamp")
398398
MODIFIER(wrap, ".wrap")
399399
MODIFIER(wide, ".wide")
400400
MODIFIER(sync, ".sync")
401+
MODIFIER(async, ".async")
402+
MODIFIER(cg, ".cg")
401403
MODIFIER(warp, ".warp")
402404
MODIFIER(up, ".up")
403405
MODIFIER(down, ".down")

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ ENTRY("clz", "clz", true, NO_FLAG, P1, "Successful")
5656
ENTRY("cnot", "cnot", true, NO_FLAG, P1, "Successful")
5757
ENTRY("copysign", "copysign", true, NO_FLAG, P1, "Successful")
5858
ENTRY("cos", "cos", true, NO_FLAG, P1, "Successful")
59-
ENTRY("cp", "cp", false, NO_FLAG, P1, "Comment")
59+
ENTRY("cp", "cp", true, NO_FLAG, P1, "Partial")
6060
ENTRY("createpolicy", "createpolicy", false, NO_FLAG, P1, "Comment")
6161
ENTRY("cvt", "cvt", true, NO_FLAG, P1, "Partial")
6262
ENTRY("cvta", "cvta", false, NO_FLAG, P1, "Comment")

clang/test/dpct/asm/cp.cu

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/cp %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/cp/cp.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -fsycl %T/cp/cp.dp.cpp -o %T/cp/cp.dp.o %}
6+
7+
// clang-format off
8+
#include <cstdint>
9+
#include <cstdint>
10+
#include <cuda_runtime.h>
11+
12+
// CHECK: inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
13+
// CHECK-NEXT: const int BYTES = 16;
14+
// CHECK-NEXT: auto smem = smem_ptr;
15+
// CHECK-NEXT: /*
16+
// CHECK-NEXT: DPCT1137:{{[0-9]+}}: ASM instruction "cp.async" is asynchronous copy, current it is migrated to synchronous copy operation. You may need to adjust the code to tune the performance.
17+
// CHECK-NEXT: */
18+
// CHECK-NEXT: {
19+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem)) = *(((uint32_t *)(uintptr_t)glob_ptr));
20+
// CHECK-NEXT: if (BYTES > 4)
21+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 1) = *(((uint32_t *)(uintptr_t)glob_ptr) + 1);
22+
// CHECK-NEXT: if (BYTES > 8)
23+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 2) = *(((uint32_t *)(uintptr_t)glob_ptr) + 2);
24+
// CHECK-NEXT: if (BYTES > 12)
25+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 3) = *(((uint32_t *)(uintptr_t)glob_ptr) + 3);
26+
// CHECK-NEXT: }
27+
// CHECK-NEXT:}
28+
// Test input: an unpredicated cp.async.cg.shared.global whose operands are the
// smem address, glob_ptr and the compile-time constant BYTES (16).
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
29+
const int BYTES = 16;
30+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
31+
asm volatile("{\n"
32+
" cp.async.cg.shared.global [%0], [%1], %2;\n"
33+
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
34+
}
35+
36+
37+
// CHECK: inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
38+
// CHECK-NEXT: bool pred = true) {
39+
// CHECK-NEXT: const int BYTES = 16;
40+
// CHECK-NEXT: auto smem = smem_ptr;
41+
// CHECK-NEXT: /*
42+
// CHECK-NEXT: DPCT1137:{{[0-9]+}}: ASM instruction "cp.async" is asynchronous copy, current it is migrated to synchronous copy operation. You may need to adjust the code to tune the performance.
43+
// CHECK-NEXT: */
44+
// CHECK-NEXT: {
45+
// CHECK-NEXT: bool p;
46+
// CHECK-NEXT: p = (int)pred != 0;
47+
// CHECK-NEXT: if (p) {
48+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem)) = *(((uint32_t *)(uintptr_t)glob_ptr));
49+
// CHECK-NEXT: if (BYTES > 4)
50+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 1) = *(((uint32_t *)(uintptr_t)glob_ptr) + 1);
51+
// CHECK-NEXT: if (BYTES > 8)
52+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 2) = *(((uint32_t *)(uintptr_t)glob_ptr) + 2);
53+
// CHECK-NEXT: if (BYTES > 12)
54+
// CHECK-NEXT: *(((uint32_t *)(uintptr_t)smem) + 3) = *(((uint32_t *)(uintptr_t)glob_ptr) + 3);
55+
// CHECK-NEXT: }
56+
// CHECK-NEXT: }
57+
// CHECK-NEXT:}
58+
// Test input: the same cp.async.cg.shared.global copy, but guarded by the
// predicate register p, which setp.ne.b32 sets from (int)pred.
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
59+
bool pred = true) {
60+
const int BYTES = 16;
61+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
62+
asm volatile("{\n"
63+
" .reg .pred p;\n"
64+
" setp.ne.b32 p, %0, 0;\n"
65+
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
66+
"}\n" ::"r"((int)pred),
67+
"r"(smem), "l"(glob_ptr), "n"(BYTES));
68+
}
69+
70+
// clang-format on

clang/test/dpct/help_option_check/lin/help_all.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ All DPCT options
127127
--rule-file=<file> - Specify the rule file for migration. Also, reference the predefined rules in the "extensions" directory in the root folder of the tool.
128128
--stop-on-parse-err - Stop migration and generation of reports if parsing errors happened. Default: off.
129129
--suppress-warnings=<value> - A comma separated list of migration warnings to suppress. Valid warning IDs range
130-
from 1000 to 1136. Hyphen separated ranges are also allowed. For example:
130+
from 1000 to 1137. Hyphen separated ranges are also allowed. For example:
131131
--suppress-warnings=1000-1010,1011.
132132
--suppress-warnings-all - Suppress all migration warnings. Default: off.
133133
--sycl-file-extension=<value> - Specify the extension of migrated source file(s).

clang/test/dpct/help_option_check/win/help_all.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ All DPCT options
126126
--rule-file=<file> - Specify the rule file for migration. Also, reference the predefined rules in the "extensions" directory in the root folder of the tool.
127127
--stop-on-parse-err - Stop migration and generation of reports if parsing errors happened. Default: off.
128128
--suppress-warnings=<value> - A comma separated list of migration warnings to suppress. Valid warning IDs range
129-
from 1000 to 1136. Hyphen separated ranges are also allowed. For example:
129+
from 1000 to 1137. Hyphen separated ranges are also allowed. For example:
130130
--suppress-warnings=1000-1010,1011.
131131
--suppress-warnings-all - Suppress all migration warnings. Default: off.
132132
--sycl-file-extension=<value> - Specify the extension of migrated source file(s).

0 commit comments

Comments
 (0)