Skip to content

Commit 0f17d5a

Browse files
authored
[SM6.10] Implement LinAlg Convert Builtin (#8308)
Implements the LinAlg Convert Builtin Fixes #8288
1 parent c689385 commit 0f17d5a

23 files changed

Lines changed: 226 additions & 41 deletions

docs/DXIL.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,9 +3095,9 @@ ID Name Description
30953095
2147483675 LinAlgMatrixAccumulateToDescriptor accumulates a matrix to a RWByteAddressBuffer
30963096
2147483676 LinAlgMatrixAccumulateToMemory accumulates a matrix to groupshared memory
30973097
2147483677 LinAlgMatrixOuterProduct Outer products an M sized vector and a N sized vector producing an MxN matrix
3098-
2147483678 ReservedE1 reserved
3099-
2147483679 ReservedE2 reserved
3100-
2147483680 ReservedE3 reserved
3098+
2147483678 LinAlgConvert Convert vector components from one interpretation to another
3099+
2147483679 ReservedE0 reserved
3100+
2147483680 ReservedE1 reserved
31013101
2147483681 DebugBreak triggers a breakpoint if a debugger is attached
31023102
2147483682 IsDebuggerPresent returns true if a debugger is attached
31033103
========== ======================================== ===================================================================================================================

include/dxc/DXIL/DxilConstants.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -531,9 +531,8 @@ static const OpCodeTableID TableID = OpCodeTableID::ExperimentalOps;
531531
// Enumeration for ExperimentalOps DXIL operations
532532
enum class OpCode : unsigned {
533533
//
534-
ReservedE1 = 30, // reserved
535-
ReservedE2 = 31, // reserved
536-
ReservedE3 = 32, // reserved
534+
ReservedE0 = 31, // reserved
535+
ReservedE1 = 32, // reserved
537536

538537
// Debugging
539538
DebugBreak = 33, // triggers a breakpoint if a debugger is attached
@@ -552,6 +551,8 @@ enum class OpCode : unsigned {
552551
9, // returns committed triangle vertices in object space as <9 x float>
553552

554553
// Linear Algebra Operations
554+
LinAlgConvert =
555+
30, // Convert vector components from one interpretation to another
555556
LinAlgCopyConvertMatrix =
556557
13, // Converts and copies the element and use type of the source matrix
557558
// to the destination matrix with optional transpose
@@ -1349,12 +1350,13 @@ enum class OpCode : unsigned {
13491350
ExperimentalOps,
13501351
LinAlgMatrixOuterProduct), // Outer products an M sized vector and a N
13511352
// sized vector producing an MxN matrix
1352-
// ReservedE1 = 0x8000001E, 2147483678U, -2147483618
1353+
// LinAlgConvert = 0x8000001E, 2147483678U, -2147483618
1354+
EXP_OPCODE(ExperimentalOps, LinAlgConvert), // Convert vector components from
1355+
// one interpretation to another
1356+
// ReservedE0 = 0x8000001F, 2147483679U, -2147483617
1357+
EXP_OPCODE(ExperimentalOps, ReservedE0), // reserved
1358+
// ReservedE1 = 0x80000020, 2147483680U, -2147483616
13531359
EXP_OPCODE(ExperimentalOps, ReservedE1), // reserved
1354-
// ReservedE2 = 0x8000001F, 2147483679U, -2147483617
1355-
EXP_OPCODE(ExperimentalOps, ReservedE2), // reserved
1356-
// ReservedE3 = 0x80000020, 2147483680U, -2147483616
1357-
EXP_OPCODE(ExperimentalOps, ReservedE3), // reserved
13581360
// DebugBreak = 0x80000021, 2147483681U, -2147483615
13591361
EXP_OPCODE(ExperimentalOps,
13601362
DebugBreak), // triggers a breakpoint if a debugger is attached
@@ -1520,6 +1522,7 @@ enum class OpCodeClass : unsigned {
15201522
CreateHandleForLib,
15211523

15221524
// Linear Algebra Operations
1525+
LinAlgConvert,
15231526
LinAlgCopyConvertMatrix,
15241527
LinAlgFillMatrix,
15251528
LinAlgMatVecMul,
@@ -1725,7 +1728,7 @@ enum class OpCodeClass : unsigned {
17251728
NodeOutputIsValid,
17261729
OutputComplete,
17271730

1728-
NumOpClasses = 221, // exclusive last value of enumeration
1731+
NumOpClasses = 222, // exclusive last value of enumeration
17291732
};
17301733
// OPCODECLASS-ENUM:END
17311734

include/dxc/DXIL/DxilInstructions.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10920,6 +10920,40 @@ struct DxilInst_LinAlgMatrixOuterProduct {
1092010920
void set_vectorB(llvm::Value *val) { Instr->setOperand(2, val); }
1092110921
};
1092210922

10923+
/// This instruction Convert vector components from one interpretation to
10924+
/// another
10925+
struct DxilInst_LinAlgConvert {
10926+
llvm::Instruction *Instr;
10927+
// Construction and identification
10928+
DxilInst_LinAlgConvert(llvm::Instruction *pInstr) : Instr(pInstr) {}
10929+
operator bool() const {
10930+
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
10931+
hlsl::OP::OpCode::LinAlgConvert);
10932+
}
10933+
// Validation support
10934+
bool isAllowed() const { return true; }
10935+
bool isArgumentListValid() const {
10936+
if (4 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10937+
return false;
10938+
return true;
10939+
}
10940+
// Metadata
10941+
bool requiresUniformInputs() const { return false; }
10942+
// Operand indexes
10943+
enum OperandIdx {
10944+
arg_inputVector = 1,
10945+
arg_inputInterpretation = 2,
10946+
arg_outputInterpretation = 3,
10947+
};
10948+
// Accessors
10949+
llvm::Value *get_inputVector() const { return Instr->getOperand(1); }
10950+
void set_inputVector(llvm::Value *val) { Instr->setOperand(1, val); }
10951+
llvm::Value *get_inputInterpretation() const { return Instr->getOperand(2); }
10952+
void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(2, val); }
10953+
llvm::Value *get_outputInterpretation() const { return Instr->getOperand(3); }
10954+
void set_outputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); }
10955+
};
10956+
1092310957
/// This instruction triggers a breakpoint if a debugger is attached
1092410958
struct DxilInst_DebugBreak {
1092510959
llvm::Instruction *Instr;

include/dxc/HlslIntrinsicOp.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ enum class IntrinsicOp {
112112
IOP_WorldToObject = 99,
113113
IOP_WorldToObject3x4 = 100,
114114
IOP_WorldToObject4x3 = 101,
115+
IOP___builtin_LinAlg_Convert = 422,
115116
IOP___builtin_LinAlg_CopyConvertMatrix = 401,
116117
IOP___builtin_LinAlg_FillMatrix = 402,
117118
IOP___builtin_LinAlg_MatrixAccumulate = 411,
@@ -428,7 +429,7 @@ enum class IntrinsicOp {
428429
IOP_usign = 355,
429430
MOP_InterlockedUMax = 356,
430431
MOP_InterlockedUMin = 357,
431-
Num_Intrinsics = 422,
432+
Num_Intrinsics = 423,
432433
};
433434
inline bool HasUnsignedIntrinsicOpcode(IntrinsicOp opcode) {
434435
switch (opcode) {

lib/DXIL/DxilOperations.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,25 +2976,25 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29762976
3,
29772977
{{0x200}, {0x400}, {0x400}},
29782978
{{0x0}, {0x63}, {0x63}}}, // Overloads: o,<hfwi,<hfwi
2979-
2980-
{OC::ReservedE1,
2981-
"ReservedE1",
2982-
OCC::Reserved,
2983-
"reserved",
2979+
{OC::LinAlgConvert,
2980+
"LinAlgConvert",
2981+
OCC::LinAlgConvert,
2982+
"linAlgConvert",
29842983
Attribute::None,
2985-
0,
2986-
{},
2987-
{}}, // Overloads: v
2988-
{OC::ReservedE2,
2989-
"ReservedE2",
2984+
2,
2985+
{{0x400}, {0x400}},
2986+
{{0x63}, {0x63}}}, // Overloads: <hfwi,<hfwi
2987+
2988+
{OC::ReservedE0,
2989+
"ReservedE0",
29902990
OCC::Reserved,
29912991
"reserved",
29922992
Attribute::None,
29932993
0,
29942994
{},
29952995
{}}, // Overloads: v
2996-
{OC::ReservedE3,
2997-
"ReservedE3",
2996+
{OC::ReservedE1,
2997+
"ReservedE1",
29982998
OCC::Reserved,
29992999
"reserved",
30003000
Attribute::None,
@@ -3955,12 +3955,13 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
39553955
// LinAlgMatrixQueryAccumulatorLayout=2147483670, LinAlgMatVecMul=2147483673,
39563956
// LinAlgMatVecMulAdd=2147483674,
39573957
// LinAlgMatrixAccumulateToDescriptor=2147483675,
3958-
// LinAlgMatrixOuterProduct=2147483677, DebugBreak=2147483681,
3959-
// IsDebuggerPresent=2147483682
3958+
// LinAlgMatrixOuterProduct=2147483677, LinAlgConvert=2147483678,
3959+
// DebugBreak=2147483681, IsDebuggerPresent=2147483682
39603960
if (op == 2147483648 || (2147483652 <= op && op <= 2147483653) ||
39613961
(2147483656 <= op && op <= 2147483657) || op == 2147483662 ||
39623962
op == 2147483670 || (2147483673 <= op && op <= 2147483675) ||
3963-
op == 2147483677 || (2147483681 <= op && op <= 2147483682)) {
3963+
(2147483677 <= op && op <= 2147483678) ||
3964+
(2147483681 <= op && op <= 2147483682)) {
39643965
major = 6;
39653966
minor = 10;
39663967
return;
@@ -6673,17 +6674,20 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66736674
A(EXT(1));
66746675
A(EXT(2));
66756676
break;
6676-
6677-
//
6678-
case OpCode::ReservedE1:
6679-
A(pV);
6677+
case OpCode::LinAlgConvert:
6678+
A(EXT(0));
6679+
A(pI32);
6680+
A(EXT(1));
6681+
A(pI32);
66806682
A(pI32);
66816683
break;
6682-
case OpCode::ReservedE2:
6684+
6685+
//
6686+
case OpCode::ReservedE0:
66836687
A(pV);
66846688
A(pI32);
66856689
break;
6686-
case OpCode::ReservedE3:
6690+
case OpCode::ReservedE1:
66876691
A(pV);
66886692
A(pI32);
66896693
break;
@@ -7002,9 +7006,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70027006
case OpCode::GetGroupWaveCount:
70037007
case OpCode::ClusterID:
70047008
case OpCode::LinAlgMatrixQueryAccumulatorLayout:
7009+
case OpCode::ReservedE0:
70057010
case OpCode::ReservedE1:
7006-
case OpCode::ReservedE2:
7007-
case OpCode::ReservedE3:
70087011
case OpCode::DebugBreak:
70097012
case OpCode::IsDebuggerPresent:
70107013
return Type::getVoidTy(Ctx);
@@ -7047,6 +7050,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70477050
case OpCode::LinAlgFillMatrix:
70487051
case OpCode::LinAlgCopyConvertMatrix:
70497052
case OpCode::LinAlgMatrixGetElement:
7053+
case OpCode::LinAlgConvert:
70507054
if (FT->getNumParams() < 2)
70517055
return nullptr;
70527056
return llvm::StructType::get(Ctx,

lib/HLSL/HLOperationLower.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7089,6 +7089,30 @@ Value *TranslateLinAlgMatrixAccumStoreToMemory(
70897089
{OpArg, Matrix, ArrPtr, Offset, Stride, Layout});
70907090
}
70917091

7092+
Value *TranslateLinAlgConvert(CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
7093+
HLOperationLowerHelper &Helper,
7094+
HLObjectOperationLowerHelper *ObjHelper,
7095+
bool &Translated) {
7096+
hlsl::OP *HlslOp = &Helper.hlslOP;
7097+
IRBuilder<> Builder(CI);
7098+
7099+
Value *OutVecPtr = CI->getArgOperand(1);
7100+
DXASSERT_NOMSG(isa<PointerType>(OutVecPtr->getType()));
7101+
Type *OutVecTy = OutVecPtr->getType()->getPointerElementType();
7102+
Value *InVec = CI->getArgOperand(2);
7103+
Value *InInterp = CI->getArgOperand(3);
7104+
Value *OutInterp = CI->getArgOperand(4);
7105+
7106+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
7107+
Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {OutVecTy, InVec->getType()});
7108+
7109+
Value *OutVec =
7110+
Builder.CreateCall(DxilFunc, {OpArg, InVec, InInterp, OutInterp});
7111+
Builder.CreateStore(OutVec, OutVecPtr);
7112+
7113+
return nullptr;
7114+
}
7115+
70927116
} // namespace
70937117

70947118
// Lower table.
@@ -7880,6 +7904,9 @@ constexpr IntrinsicLower gLowerTable[] = {
78807904
DXIL::OpCode::DebugBreak},
78817905
{IntrinsicOp::IOP_DxIsDebuggerPresent, TranslateWaveToVal,
78827906
DXIL::OpCode::IsDebuggerPresent},
7907+
7908+
{IntrinsicOp::IOP___builtin_LinAlg_Convert, TranslateLinAlgConvert,
7909+
DXIL::OpCode::LinAlgConvert},
78837910
};
78847911
constexpr size_t NumLowerTableEntries =
78857912
sizeof(gLowerTable) / sizeof(gLowerTable[0]);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s
3+
// RUN: %dxc -T cs_6_10 -HV 202x -E main -fcgl %s | FileCheck %s --check-prefix=CHECK2
4+
5+
[numthreads(4,1,1)]
6+
void main() {
7+
// CHECK-LABEL: define void @main()
8+
9+
// CHECK: %{{.*}} = call <4 x i32> @dx.op.linAlgConvert.v4i32.v4f32
10+
// CHECK-SAME: (i32 -2147483618, <4 x float> <float 9.000000e+00, float 8.000000e+00, float 7.000000e+00, float 6.000000e+00>, i32 1, i32 2)
11+
// CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
12+
13+
// CHECK2: call void @"dx.hl.op..void (i32, <4 x i32>*, <4 x float>, i32, i32)"
14+
// CHECK2-SAME: (i32 422, <4 x i32>* %result, <4 x float> %{{.*}}, i32 1, i32 2)
15+
float4 vec = {9.0, 8.0, 7.0, 6.0};
16+
int4 result;
17+
__builtin_LinAlg_Convert(result, vec, 1, 2);
18+
}

tools/clang/test/LitDXILValidation/LinAlgMatrix/linalgmatrix-as.ll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ define void @mainAS() {
5050
; dx.op.linAlgMatVecMulAdd
5151
%v7 = call <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32 -2147483622, %dx.types.LinAlgMatrixC4M5N4U0S2 %v4, <4 x i32> <i32 9, i32 9, i32 9, i32 9>, i32 2, <4 x i32> <i32 7, i32 7, i32 7, i32 7>, i32 3) ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
5252

53+
; dx.op.linAlgConvert
54+
%v16 = call <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32 -2147483618, <4 x i32> zeroinitializer, i32 1, i32 2) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
55+
5356
;
5457
; Built-ins restricted to compute, mesh and amplification shaders
5558
;
@@ -123,6 +126,9 @@ declare <4 x i32> @dx.op.linAlgMatVecMul.v4i32.mC4M5N4U0S2.v4i32(i32, %dx.types.
123126
; Function Attrs: nounwind
124127
declare <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, <4 x i32>, i32, <4 x i32>, i32) #0
125128

129+
; Function Attrs: nounwind
130+
declare <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32, <4 x i32>, i32, i32) #0
131+
126132
; Function Attrs: nounwind
127133
declare %dx.types.LinAlgMatrixC4M4N5U1S2 @dx.op.linAlgCopyConvertMatrix.mC4M4N5U1S2.mC4M5N4U0S2(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, i1) #0
128134

tools/clang/test/LitDXILValidation/LinAlgMatrix/linalgmatrix-cs.ll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ define void @mainCS() {
4949
; dx.op.linAlgMatVecMulAdd
5050
%v7 = call <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32 -2147483622, %dx.types.LinAlgMatrixC4M5N4U0S2 %v4, <4 x i32> <i32 9, i32 9, i32 9, i32 9>, i32 2, <4 x i32> <i32 7, i32 7, i32 7, i32 7>, i32 3) ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
5151

52+
; dx.op.linAlgConvert
53+
%v16 = call <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32 -2147483618, <4 x i32> zeroinitializer, i32 1, i32 2) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
54+
5255
;
5356
; Built-ins restricted to compute, mesh and amplification shaders
5457
;
@@ -119,6 +122,9 @@ declare <4 x i32> @dx.op.linAlgMatVecMul.v4i32.mC4M5N4U0S2.v4i32(i32, %dx.types.
119122
; Function Attrs: nounwind
120123
declare <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, <4 x i32>, i32, <4 x i32>, i32) #0
121124

125+
; Function Attrs: nounwind
126+
declare <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32, <4 x i32>, i32, i32) #0
127+
122128
; Function Attrs: nounwind
123129
declare %dx.types.LinAlgMatrixC4M4N5U1S2 @dx.op.linAlgCopyConvertMatrix.mC4M4N5U1S2.mC4M5N4U0S2(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, i1) #0
124130

tools/clang/test/LitDXILValidation/LinAlgMatrix/linalgmatrix-ds.ll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ define void @MainDS() {
6565
; dx.op.linAlgMatVecMulAdd
6666
%v7 = call <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32 -2147483622, %dx.types.LinAlgMatrixC4M5N4U0S2 %v4, <4 x i32> <i32 9, i32 9, i32 9, i32 9>, i32 2, <4 x i32> <i32 7, i32 7, i32 7, i32 7>, i32 3) ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
6767

68+
; dx.op.linAlgConvert
69+
%v16 = call <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32 -2147483618, <4 x i32> zeroinitializer, i32 1, i32 2) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
70+
6871
;
6972
; Built-ins restricted to compute, mesh and amplification shaders
7073
;
@@ -143,6 +146,9 @@ declare <4 x i32> @dx.op.linAlgMatVecMul.v4i32.mC4M5N4U0S2.v4i32(i32, %dx.types.
143146
; Function Attrs: nounwind
144147
declare <4 x i32> @dx.op.linAlgMatVecMulAdd.v4i32.mC4M5N4U0S2.v4i32.v4i32(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, <4 x i32>, i32, <4 x i32>, i32) #0
145148

149+
; Function Attrs: nounwind
150+
declare <4 x float> @dx.op.linAlgConvert.v4f32.v4i32(i32, <4 x i32>, i32, i32) #0
151+
146152
; Function Attrs: nounwind
147153
declare %dx.types.LinAlgMatrixC4M4N5U1S2 @dx.op.linAlgCopyConvertMatrix.mC4M4N5U1S2.mC4M5N4U0S2(i32, %dx.types.LinAlgMatrixC4M5N4U0S2, i1) #0
148154

0 commit comments

Comments
 (0)