Skip to content

Commit 720acdf

Browse files
authored
[SM6.10][specs/791] Align LinAlg Matrix Mul with spec (microsoft#8183)
microsoft/hlsl-specs#791 splits the MulOp operation into two distinct ops, LinAlgMatrixMultiply and LinAlgMatrixMultiplyAccumulate. Update the placeholder code to reflect that. The interesting changes are in `gen_intrin_main.txt`, `hctdb.py`, and `HLOperationLower.cpp`; the rest of the changes are generated code.
1 parent 435dbb6 commit 720acdf

7 files changed

Lines changed: 113 additions & 54 deletions

File tree

docs/DXIL.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3076,7 +3076,7 @@ ID Name Description
30763076
2147483656 RayQuery_CandidateTriangleObjectPosition returns candidate triangle vertices in object space as <9 x float>
30773077
2147483657 RayQuery_CommittedTriangleObjectPosition returns committed triangle vertices in object space as <9 x float>
30783078
2147483658 HitObject_TriangleObjectPosition returns triangle vertices in object space as <9 x float>
3079-
2147483659 ReservedD0 reserved
3079+
2147483659 LinAlgMatrixMultiplyAccumulate Returns the resulting matrix from multiplying A and B and accumulating into C
30803080
2147483660 LinAlgFillMatrix fills a matrix with a scalar value
30813081
2147483661 LinAlgCopyConvertMatrix Converts and copies the element and use type of the source matrix to the destination matrix with optional transpose
30823082
2147483662 LinAlgMatrixLoadFromDescriptor fills a matrix with data from a [RW]ByteAddressBuffer
@@ -3088,7 +3088,7 @@ ID Name Description
30883088
2147483668 LinAlgMatrixStoreToDescriptor stores a matrix to a RWByteAddressBuffer
30893089
2147483669 LinAlgMatrixStoreToMemory stores a matrix to groupshared memory
30903090
2147483670 LinAlgMatrixQueryAccumulatorLayout returns comptime 0 when accumulator matrix are A layout, 1 when B layout
3091-
2147483671 LinAlgMatrixMulOp applies a multiplication op to matrix C using A and B as parameters
3091+
2147483671 LinAlgMatrixMultiply Returns the resulting matrix from multiplying A and B
30923092
2147483672 LinAlgMatrixAccumulate accumulate A or B matrix into Accumulator matrix following LHS += RHS
30933093
2147483673 LinAlgMatVecMul Multiplies a MxK dimension matrix and a K sized input vector
30943094
2147483674 LinAlgMatVecMulAdd Multiplies a MxK dimension matrix and a K sized input vector then adds a M sized bias vector

include/dxc/DXIL/DxilConstants.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,6 @@ static const OpCodeTableID TableID = OpCodeTableID::ExperimentalOps;
524524
// Enumeration for ExperimentalOps DXIL operations
525525
enum class OpCode : unsigned {
526526
//
527-
ReservedD0 = 11, // reserved
528527
ReservedD1 = 30, // reserved
529528
ReservedD2 = 31, // reserved
530529
ReservedD3 = 32, // reserved
@@ -573,8 +572,11 @@ enum class OpCode : unsigned {
573572
14, // fills a matrix with data from a [RW]ByteAddressBuffer
574573
LinAlgMatrixLoadFromMemory =
575574
15, // fills a matrix with data from a groupshared array
576-
LinAlgMatrixMulOp =
577-
23, // applies a multiplication op to matrix C using A and B as parameters
575+
LinAlgMatrixMultiply =
576+
23, // Returns the resulting matrix from multiplying A and B
577+
LinAlgMatrixMultiplyAccumulate =
578+
11, // Returns the resulting matrix from multiplying A and B and
579+
// accumulating into C
578580
LinAlgMatrixOuterProduct = 29, // Outer products an M sized vector and a N
579581
// sized vector producing an MxN matrix
580582
LinAlgMatrixQueryAccumulatorLayout =
@@ -1263,8 +1265,11 @@ enum class OpCode : unsigned {
12631265
EXP_OPCODE(ExperimentalOps,
12641266
HitObject_TriangleObjectPosition), // returns triangle vertices in
12651267
// object space as <9 x float>
1266-
// ReservedD0 = 0x8000000B, 2147483659U, -2147483637
1267-
EXP_OPCODE(ExperimentalOps, ReservedD0), // reserved
1268+
// LinAlgMatrixMultiplyAccumulate = 0x8000000B, 2147483659U, -2147483637
1269+
EXP_OPCODE(ExperimentalOps,
1270+
LinAlgMatrixMultiplyAccumulate), // Returns the resulting matrix
1271+
// from multiplying A and B and
1272+
// accumulating into C
12681273
// LinAlgFillMatrix = 0x8000000C, 2147483660U, -2147483636
12691274
EXP_OPCODE(ExperimentalOps,
12701275
LinAlgFillMatrix), // fills a matrix with a scalar value
@@ -1316,10 +1321,10 @@ enum class OpCode : unsigned {
13161321
LinAlgMatrixQueryAccumulatorLayout), // returns comptime 0 when
13171322
// accumulator matrix are A
13181323
// layout, 1 when B layout
1319-
// LinAlgMatrixMulOp = 0x80000017, 2147483671U, -2147483625
1324+
// LinAlgMatrixMultiply = 0x80000017, 2147483671U, -2147483625
13201325
EXP_OPCODE(ExperimentalOps,
1321-
LinAlgMatrixMulOp), // applies a multiplication op to matrix C
1322-
// using A and B as parameters
1326+
LinAlgMatrixMultiply), // Returns the resulting matrix from
1327+
// multiplying A and B
13231328
// LinAlgMatrixAccumulate = 0x80000018, 2147483672U, -2147483624
13241329
EXP_OPCODE(ExperimentalOps,
13251330
LinAlgMatrixAccumulate), // accumulate A or B matrix into
@@ -1529,7 +1534,8 @@ enum class OpCodeClass : unsigned {
15291534
LinAlgMatrixLength,
15301535
LinAlgMatrixLoadFromDescriptor,
15311536
LinAlgMatrixLoadFromMemory,
1532-
LinAlgMatrixMulOp,
1537+
LinAlgMatrixMultiply,
1538+
LinAlgMatrixMultiplyAccumulate,
15331539
LinAlgMatrixOuterProduct,
15341540
LinAlgMatrixQueryAccumulatorLayout,
15351541
LinAlgMatrixSetElement,
@@ -1725,7 +1731,7 @@ enum class OpCodeClass : unsigned {
17251731
NodeOutputIsValid,
17261732
OutputComplete,
17271733

1728-
NumOpClasses = 224, // exclusive last value of enumeration
1734+
NumOpClasses = 225, // exclusive last value of enumeration
17291735
};
17301736
// OPCODECLASS-ENUM:END
17311737

include/dxc/DXIL/DxilInstructions.h

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10500,6 +10500,41 @@ struct DxilInst_HitObject_TriangleObjectPosition {
1050010500
void set_hitObject(llvm::Value *val) { Instr->setOperand(1, val); }
1050110501
};
1050210502

10503+
/// This instruction Returns the resulting matrix from multiplying A and B and
10504+
/// accumulating into C
10505+
struct DxilInst_LinAlgMatrixMultiplyAccumulate {
10506+
llvm::Instruction *Instr;
10507+
// Construction and identification
10508+
DxilInst_LinAlgMatrixMultiplyAccumulate(llvm::Instruction *pInstr)
10509+
: Instr(pInstr) {}
10510+
operator bool() const {
10511+
return hlsl::OP::IsDxilOpFuncCallInst(
10512+
Instr, hlsl::OP::OpCode::LinAlgMatrixMultiplyAccumulate);
10513+
}
10514+
// Validation support
10515+
bool isAllowed() const { return true; }
10516+
bool isArgumentListValid() const {
10517+
if (4 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10518+
return false;
10519+
return true;
10520+
}
10521+
// Metadata
10522+
bool requiresUniformInputs() const { return false; }
10523+
// Operand indexes
10524+
enum OperandIdx {
10525+
arg_matrixA = 1,
10526+
arg_matrixB = 2,
10527+
arg_matrixC = 3,
10528+
};
10529+
// Accessors
10530+
llvm::Value *get_matrixA() const { return Instr->getOperand(1); }
10531+
void set_matrixA(llvm::Value *val) { Instr->setOperand(1, val); }
10532+
llvm::Value *get_matrixB() const { return Instr->getOperand(2); }
10533+
void set_matrixB(llvm::Value *val) { Instr->setOperand(2, val); }
10534+
llvm::Value *get_matrixC() const { return Instr->getOperand(3); }
10535+
void set_matrixC(llvm::Value *val) { Instr->setOperand(3, val); }
10536+
};
10537+
1050310538
/// This instruction fills a matrix with a scalar value
1050410539
struct DxilInst_LinAlgFillMatrix {
1050510540
llvm::Instruction *Instr;
@@ -10859,15 +10894,14 @@ struct DxilInst_LinAlgMatrixQueryAccumulatorLayout {
1085910894
bool requiresUniformInputs() const { return false; }
1086010895
};
1086110896

10862-
/// This instruction applies a multiplication op to matrix C using A and B as
10863-
/// parameters
10864-
struct DxilInst_LinAlgMatrixMulOp {
10897+
/// This instruction Returns the resulting matrix from multiplying A and B
10898+
struct DxilInst_LinAlgMatrixMultiply {
1086510899
llvm::Instruction *Instr;
1086610900
// Construction and identification
10867-
DxilInst_LinAlgMatrixMulOp(llvm::Instruction *pInstr) : Instr(pInstr) {}
10901+
DxilInst_LinAlgMatrixMultiply(llvm::Instruction *pInstr) : Instr(pInstr) {}
1086810902
operator bool() const {
10869-
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
10870-
hlsl::OP::OpCode::LinAlgMatrixMulOp);
10903+
return hlsl::OP::IsDxilOpFuncCallInst(
10904+
Instr, hlsl::OP::OpCode::LinAlgMatrixMultiply);
1087110905
}
1087210906
// Validation support
1087310907
bool isAllowed() const { return true; }

lib/DXIL/DxilOperations.cpp

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,16 +2823,15 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
28232823
{{0x2}},
28242824
{{0x0}}}, // Overloads: f
28252825

2826-
{OC::ReservedD0,
2827-
"ReservedD0",
2828-
OCC::Reserved,
2829-
"reserved",
2830-
Attribute::None,
2831-
0,
2832-
{},
2833-
{}}, // Overloads: v
2834-
28352826
// Linear Algebra Operations
2827+
{OC::LinAlgMatrixMultiplyAccumulate,
2828+
"LinAlgMatrixMultiplyAccumulate",
2829+
OCC::LinAlgMatrixMultiplyAccumulate,
2830+
"linAlgMatrixMultiplyAccumulate",
2831+
Attribute::None,
2832+
4,
2833+
{{0x200}, {0x200}, {0x200}, {0x200}},
2834+
{{0x0}, {0x0}, {0x0}, {0x0}}}, // Overloads: o,o,o,o
28362835
{OC::LinAlgFillMatrix,
28372836
"LinAlgFillMatrix",
28382837
OCC::LinAlgFillMatrix,
@@ -2921,10 +2920,10 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29212920
0,
29222921
{},
29232922
{}}, // Overloads: v
2924-
{OC::LinAlgMatrixMulOp,
2925-
"LinAlgMatrixMulOp",
2926-
OCC::LinAlgMatrixMulOp,
2927-
"linAlgMatrixMulOp",
2923+
{OC::LinAlgMatrixMultiply,
2924+
"LinAlgMatrixMultiply",
2925+
OCC::LinAlgMatrixMultiply,
2926+
"linAlgMatrixMultiply",
29282927
Attribute::None,
29292928
3,
29302929
{{0x200}, {0x200}, {0x200}},
@@ -3950,15 +3949,16 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
39503949
minor = 10;
39513950
return;
39523951
}
3953-
// Instructions: LinAlgFillMatrix=2147483660,
3954-
// LinAlgCopyConvertMatrix=2147483661, LinAlgMatrixLoadFromMemory=2147483663,
3955-
// LinAlgMatrixLength=2147483664, LinAlgMatrixGetCoordinate=2147483665,
3956-
// LinAlgMatrixGetElement=2147483666, LinAlgMatrixSetElement=2147483667,
3952+
// Instructions: LinAlgMatrixMultiplyAccumulate=2147483659,
3953+
// LinAlgFillMatrix=2147483660, LinAlgCopyConvertMatrix=2147483661,
3954+
// LinAlgMatrixLoadFromMemory=2147483663, LinAlgMatrixLength=2147483664,
3955+
// LinAlgMatrixGetCoordinate=2147483665, LinAlgMatrixGetElement=2147483666,
3956+
// LinAlgMatrixSetElement=2147483667,
39573957
// LinAlgMatrixStoreToDescriptor=2147483668,
3958-
// LinAlgMatrixStoreToMemory=2147483669, LinAlgMatrixMulOp=2147483671,
3958+
// LinAlgMatrixStoreToMemory=2147483669, LinAlgMatrixMultiply=2147483671,
39593959
// LinAlgMatrixAccumulate=2147483672,
39603960
// LinAlgMatrixAccumulateToMemory=2147483676
3961-
if ((2147483660 <= op && op <= 2147483661) ||
3961+
if ((2147483659 <= op && op <= 2147483661) ||
39623962
(2147483663 <= op && op <= 2147483669) ||
39633963
(2147483671 <= op && op <= 2147483672) || op == 2147483676) {
39643964
major = 6;
@@ -6557,13 +6557,14 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
65576557
A(pHit);
65586558
break;
65596559

6560-
//
6561-
case OpCode::ReservedD0:
6562-
A(pV);
6560+
// Linear Algebra Operations
6561+
case OpCode::LinAlgMatrixMultiplyAccumulate:
6562+
EXT(0);
65636563
A(pI32);
6564+
EXT(1);
6565+
EXT(2);
6566+
EXT(3);
65646567
break;
6565-
6566-
// Linear Algebra Operations
65676568
case OpCode::LinAlgFillMatrix:
65686569
EXT(0);
65696570
A(pI32);
@@ -6637,7 +6638,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66376638
A(pI32);
66386639
A(pI32);
66396640
break;
6640-
case OpCode::LinAlgMatrixMulOp:
6641+
case OpCode::LinAlgMatrixMultiply:
66416642
EXT(0);
66426643
A(pI32);
66436644
EXT(1);
@@ -7013,7 +7014,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70137014
case OpCode::GetGroupWaveIndex:
70147015
case OpCode::GetGroupWaveCount:
70157016
case OpCode::ClusterID:
7016-
case OpCode::ReservedD0:
70177017
case OpCode::LinAlgMatrixQueryAccumulatorLayout:
70187018
case OpCode::ReservedD1:
70197019
case OpCode::ReservedD2:
@@ -7070,13 +7070,20 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70707070
return llvm::StructType::get(Ctx,
70717071
{FT->getParamType(1), FT->getParamType(2)});
70727072

7073+
case OpCode::LinAlgMatrixMultiplyAccumulate:
7074+
if (FT->getNumParams() < 4)
7075+
return nullptr;
7076+
return llvm::StructType::get(Ctx,
7077+
{FT->getReturnType(), FT->getParamType(1),
7078+
FT->getParamType(2), FT->getParamType(3)});
7079+
70737080
case OpCode::LinAlgMatrixSetElement:
70747081
if (FT->getNumParams() < 4)
70757082
return nullptr;
70767083
return llvm::StructType::get(
70777084
Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)});
70787085

7079-
case OpCode::LinAlgMatrixMulOp:
7086+
case OpCode::LinAlgMatrixMultiply:
70807087
case OpCode::LinAlgMatrixAccumulate:
70817088
case OpCode::LinAlgMatVecMul:
70827089
case OpCode::LinAlgMatrixOuterProduct:

lib/HLSL/HLOperationLower.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7680,9 +7680,9 @@ constexpr IntrinsicLower gLowerTable[] = {
76807680
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate, EmptyLower,
76817681
DXIL::OpCode::LinAlgMatrixAccumulate},
76827682
{IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiply, EmptyLower,
7683-
DXIL::OpCode::LinAlgMatrixMulOp},
7683+
DXIL::OpCode::LinAlgMatrixMultiply},
76847684
{IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiplyAccumulate,
7685-
EmptyLower, DXIL::OpCode::LinAlgMatrixMulOp},
7685+
EmptyLower, DXIL::OpCode::LinAlgMatrixMultiplyAccumulate},
76867686
{IntrinsicOp::IOP___builtin_LinAlg_MatrixQueryAccumulatorLayout, EmptyLower,
76877687
DXIL::OpCode::LinAlgMatrixQueryAccumulatorLayout},
76887688
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor, EmptyLower,

utils/hct/gen_intrin_main.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix ma
411411
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
412412
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
413413
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
414-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
414+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);
415415
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
416416
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp);
417417
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp);

utils/hct/hctdb.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,8 +1174,8 @@ def populate_categories_and_models_ExperimentalOps(self):
11741174
+ "LinAlgMatrixGetCoordinate,LinAlgMatrixGetElement,"
11751175
+ "LinAlgMatrixSetElement,LinAlgMatrixStoreToDescriptor,"
11761176
+ "LinAlgMatrixLoadFromMemory,LinAlgMatrixStoreToMemory,"
1177-
+ "LinAlgMatrixAccumulateToMemory,LinAlgMatrixMulOp,"
1178-
+ "LinAlgMatrixAccumulate"
1177+
+ "LinAlgMatrixAccumulateToMemory,LinAlgMatrixMultiply,"
1178+
+ "LinAlgMatrixMultiplyAccumulate,LinAlgMatrixAccumulate"
11791179
):
11801180
i.category = "Linear Algebra Operations"
11811181
i.shader_model = experimental_sm
@@ -6341,7 +6341,19 @@ def populate_ExperimentalOps(self):
63416341
)
63426342

63436343
# Linear Algebra Ops
6344-
op_table.reserve_dxil_op_range("ReservedD", 1)
6344+
add_dxil_op(
6345+
"LinAlgMatrixMultiplyAccumulate",
6346+
"LinAlgMatrixMultiplyAccumulate",
6347+
"Returns the resulting matrix from multiplying A and B and accumulating into C",
6348+
"o,o,o,o",
6349+
"",
6350+
[
6351+
db_dxil_param(0, "$x0", "", "resulting matrix"),
6352+
db_dxil_param(2, "$x1", "matrixA", "A matrix"),
6353+
db_dxil_param(3, "$x2", "matrixB", "B matrix"),
6354+
db_dxil_param(4, "$x3", "matrixC", "C matrix"),
6355+
],
6356+
)
63456357

63466358
add_dxil_op(
63476359
"LinAlgFillMatrix",
@@ -6530,9 +6542,9 @@ def populate_ExperimentalOps(self):
65306542
)
65316543

65326544
add_dxil_op(
6533-
"LinAlgMatrixMulOp",
6534-
"LinAlgMatrixMulOp",
6535-
"applies a multiplication op to matrix C using A and B as parameters",
6545+
"LinAlgMatrixMultiply",
6546+
"LinAlgMatrixMultiply",
6547+
"Returns the resulting matrix from multiplying A and B",
65366548
"o,o,o",
65376549
"",
65386550
[

0 commit comments

Comments
 (0)