@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556 return SYCLGenError ();
557557 OS () << " , " ;
558558 switch (T->getKind ()) {
559+ case InlineAsmVectorType::v1:
560+ OS () << 1 ;
561+ break ;
559562 case InlineAsmVectorType::v2:
560563 OS () << 2 ;
561564 break ;
@@ -1342,7 +1345,7 @@ class SYCLGen : public SYCLGenBase {
13421345 return SYCLGenError ();
13431346
13441347 // Register sizes for vector elements of A, B, C & D matrices
1345- int NumVecElements[4 ] = {0 };
1348+ unsigned NumVecElements[4 ] = {0 };
13461349
13471350 // Data type used to multiply A & B matrices
13481351 std::string MulType;
@@ -1351,8 +1354,8 @@ class SYCLGen : public SYCLGenBase {
13511354 if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
13521355 // If A matrix type is f16, then C&D matrix types can only be f16
13531356 if (CType->getKind () == AType->getKind ()) {
1354- NumVecElements[0 ] = 4 ; // A
1355- NumVecElements[1 ] = 2 ; // B
1357+ NumVecElements[0 ] = 2 ; // A
1358+ NumVecElements[1 ] = 4 ; // B
13561359 NumVecElements[2 ] = 4 ; // C
13571360 NumVecElements[3 ] = 4 ; // D
13581361 } else
@@ -1364,23 +1367,23 @@ class SYCLGen : public SYCLGenBase {
13641367 if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
13651368 // If A matrix type is f16, then C&D matrix types can only be f16/f32
13661369 if (CType->getKind () == AType->getKind ()) {
1367- NumVecElements[0 ] = 4 ; // A
1370+ NumVecElements[0 ] = 2 ; // A
13681371 NumVecElements[1 ] = 2 ; // B
1369- NumVecElements[2 ] = 2 ; // C
1372+ NumVecElements[2 ] = 4 ; // C
13701373 NumVecElements[3 ] = 4 ; // D
13711374 } else if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1372- NumVecElements[0 ] = 8 ; // A
1375+ NumVecElements[0 ] = 2 ; // A
13731376 NumVecElements[1 ] = 2 ; // B
1374- NumVecElements[2 ] = 2 ; // C
1377+ NumVecElements[2 ] = 8 ; // C
13751378 NumVecElements[3 ] = 8 ; // D
13761379 } else
13771380 return SYCLGenError ();
13781381 } else if (AType->getKind () == InlineAsmBuiltinType::f64 ) {
13791382 // If A matrix type is f64, then C&D matrix types can only be f64
13801383 if (CType->getKind () == AType->getKind ()) {
1381- NumVecElements[0 ] = 2 ; // A
1384+ NumVecElements[0 ] = 1 ; // A
13821385 NumVecElements[1 ] = 1 ; // B
1383- NumVecElements[2 ] = 1 ; // C
1386+ NumVecElements[2 ] = 2 ; // C
13841387 NumVecElements[3 ] = 2 ; // D
13851388 } else
13861389 return SYCLGenError ();
@@ -1411,15 +1414,16 @@ class SYCLGen : public SYCLGenBase {
14111414 if (isa<InlineAsmDiscardExpr>(DMatVE->getElement (Inst)))
14121415 continue ;
14131416 OS () << " &" ;
1414- if (emitStmt (VE ->getElement (Inst)))
1417+ if (emitStmt (DMatVE ->getElement (Inst)))
14151418 return SYCLGenError ();
14161419 OS () << " , " ;
14171420 }
14181421
14191422 // Add A, B & C matrix values to compute MAD
14201423 for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
14211424 InputOp++) {
1422- if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1425+ if (auto VE =
1426+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
14231427 for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
14241428 if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
14251429 continue ;
0 commit comments