Skip to content

Commit f8f2d31

Browse files
committed
Merge branch 'main' of https://github.com/microsoft/DirectXShaderCompiler into linalg-matrix-header
2 parents 18f7a75 + 097ab6b commit f8f2d31

25 files changed

Lines changed: 350 additions & 150 deletions

docs/DXIL.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,9 +3095,9 @@ ID Name Description
30953095
2147483675 LinAlgMatrixAccumulateToDescriptor accumulates a matrix to a RWByteAddressBuffer
30963096
2147483676 LinAlgMatrixAccumulateToMemory accumulates a matrix to groupshared memory
30973097
2147483677 LinAlgMatrixOuterProduct Computes the outer product of an M sized vector and an N sized vector, producing an MxN matrix
3098-
2147483678 ReservedE1 reserved
3099-
2147483679 ReservedE2 reserved
3100-
2147483680 ReservedE3 reserved
3098+
2147483678 LinAlgConvert converts vector components from one interpretation to another
3099+
2147483679 ReservedE0 reserved
3100+
2147483680 ReservedE1 reserved
31013101
2147483681 DebugBreak triggers a breakpoint if a debugger is attached
31023102
2147483682 IsDebuggerPresent returns true if a debugger is attached
31033103
========== ======================================== ===================================================================================================================

include/dxc/DXIL/DxilConstants.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -531,9 +531,8 @@ static const OpCodeTableID TableID = OpCodeTableID::ExperimentalOps;
531531
// Enumeration for ExperimentalOps DXIL operations
532532
enum class OpCode : unsigned {
533533
//
534-
ReservedE1 = 30, // reserved
535-
ReservedE2 = 31, // reserved
536-
ReservedE3 = 32, // reserved
534+
ReservedE0 = 31, // reserved
535+
ReservedE1 = 32, // reserved
537536

538537
// Debugging
539538
DebugBreak = 33, // triggers a breakpoint if a debugger is attached
@@ -552,6 +551,8 @@ enum class OpCode : unsigned {
552551
9, // returns committed triangle vertices in object space as <9 x float>
553552

554553
// Linear Algebra Operations
554+
LinAlgConvert =
555+
30, // Convert vector components from one interpretation to another
555556
LinAlgCopyConvertMatrix =
556557
13, // Converts and copies the element and use type of the source matrix
557558
// to the destination matrix with optional transpose
@@ -1349,12 +1350,13 @@ enum class OpCode : unsigned {
13491350
ExperimentalOps,
13501351
LinAlgMatrixOuterProduct), // Outer products an M sized vector and a N
13511352
// sized vector producing an MxN matrix
1352-
// ReservedE1 = 0x8000001E, 2147483678U, -2147483618
1353+
// LinAlgConvert = 0x8000001E, 2147483678U, -2147483618
1354+
EXP_OPCODE(ExperimentalOps, LinAlgConvert), // Convert vector components from
1355+
// one interpretation to another
1356+
// ReservedE0 = 0x8000001F, 2147483679U, -2147483617
1357+
EXP_OPCODE(ExperimentalOps, ReservedE0), // reserved
1358+
// ReservedE1 = 0x80000020, 2147483680U, -2147483616
13531359
EXP_OPCODE(ExperimentalOps, ReservedE1), // reserved
1354-
// ReservedE2 = 0x8000001F, 2147483679U, -2147483617
1355-
EXP_OPCODE(ExperimentalOps, ReservedE2), // reserved
1356-
// ReservedE3 = 0x80000020, 2147483680U, -2147483616
1357-
EXP_OPCODE(ExperimentalOps, ReservedE3), // reserved
13581360
// DebugBreak = 0x80000021, 2147483681U, -2147483615
13591361
EXP_OPCODE(ExperimentalOps,
13601362
DebugBreak), // triggers a breakpoint if a debugger is attached
@@ -1520,6 +1522,7 @@ enum class OpCodeClass : unsigned {
15201522
CreateHandleForLib,
15211523

15221524
// Linear Algebra Operations
1525+
LinAlgConvert,
15231526
LinAlgCopyConvertMatrix,
15241527
LinAlgFillMatrix,
15251528
LinAlgMatVecMul,
@@ -1725,7 +1728,7 @@ enum class OpCodeClass : unsigned {
17251728
NodeOutputIsValid,
17261729
OutputComplete,
17271730

1728-
NumOpClasses = 221, // exclusive last value of enumeration
1731+
NumOpClasses = 222, // exclusive last value of enumeration
17291732
};
17301733
// OPCODECLASS-ENUM:END
17311734

include/dxc/DXIL/DxilInstructions.h

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10744,7 +10744,7 @@ struct DxilInst_LinAlgMatVecMul {
1074410744
// Validation support
1074510745
bool isAllowed() const { return true; }
1074610746
bool isArgumentListValid() const {
10747-
if (4 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10747+
if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
1074810748
return false;
1074910749
return true;
1075010750
}
@@ -10753,16 +10753,19 @@ struct DxilInst_LinAlgMatVecMul {
1075310753
// Operand indexes
1075410754
enum OperandIdx {
1075510755
arg_matrix = 1,
10756-
arg_inputVector = 2,
10757-
arg_interpretation = 3,
10756+
arg_isOutputSigned = 2,
10757+
arg_inputVector = 3,
10758+
arg_interpretation = 4,
1075810759
};
1075910760
// Accessors
1076010761
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1076110762
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
10762-
llvm::Value *get_inputVector() const { return Instr->getOperand(2); }
10763-
void set_inputVector(llvm::Value *val) { Instr->setOperand(2, val); }
10764-
llvm::Value *get_interpretation() const { return Instr->getOperand(3); }
10765-
void set_interpretation(llvm::Value *val) { Instr->setOperand(3, val); }
10763+
llvm::Value *get_isOutputSigned() const { return Instr->getOperand(2); }
10764+
void set_isOutputSigned(llvm::Value *val) { Instr->setOperand(2, val); }
10765+
llvm::Value *get_inputVector() const { return Instr->getOperand(3); }
10766+
void set_inputVector(llvm::Value *val) { Instr->setOperand(3, val); }
10767+
llvm::Value *get_interpretation() const { return Instr->getOperand(4); }
10768+
void set_interpretation(llvm::Value *val) { Instr->setOperand(4, val); }
1076610769
};
1076710770

1076810771
/// This instruction Multiplies a MxK dimension matrix and a K sized input
@@ -10778,7 +10781,7 @@ struct DxilInst_LinAlgMatVecMulAdd {
1077810781
// Validation support
1077910782
bool isAllowed() const { return true; }
1078010783
bool isArgumentListValid() const {
10781-
if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10784+
if (7 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
1078210785
return false;
1078310786
return true;
1078410787
}
@@ -10787,22 +10790,25 @@ struct DxilInst_LinAlgMatVecMulAdd {
1078710790
// Operand indexes
1078810791
enum OperandIdx {
1078910792
arg_matrix = 1,
10790-
arg_inputVector = 2,
10791-
arg_inputInterpretation = 3,
10792-
arg_biasVector = 4,
10793-
arg_biasInterpretation = 5,
10793+
arg_isOutputSigned = 2,
10794+
arg_inputVector = 3,
10795+
arg_inputInterpretation = 4,
10796+
arg_biasVector = 5,
10797+
arg_biasInterpretation = 6,
1079410798
};
1079510799
// Accessors
1079610800
llvm::Value *get_matrix() const { return Instr->getOperand(1); }
1079710801
void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); }
10798-
llvm::Value *get_inputVector() const { return Instr->getOperand(2); }
10799-
void set_inputVector(llvm::Value *val) { Instr->setOperand(2, val); }
10800-
llvm::Value *get_inputInterpretation() const { return Instr->getOperand(3); }
10801-
void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); }
10802-
llvm::Value *get_biasVector() const { return Instr->getOperand(4); }
10803-
void set_biasVector(llvm::Value *val) { Instr->setOperand(4, val); }
10804-
llvm::Value *get_biasInterpretation() const { return Instr->getOperand(5); }
10805-
void set_biasInterpretation(llvm::Value *val) { Instr->setOperand(5, val); }
10802+
llvm::Value *get_isOutputSigned() const { return Instr->getOperand(2); }
10803+
void set_isOutputSigned(llvm::Value *val) { Instr->setOperand(2, val); }
10804+
llvm::Value *get_inputVector() const { return Instr->getOperand(3); }
10805+
void set_inputVector(llvm::Value *val) { Instr->setOperand(3, val); }
10806+
llvm::Value *get_inputInterpretation() const { return Instr->getOperand(4); }
10807+
void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(4, val); }
10808+
llvm::Value *get_biasVector() const { return Instr->getOperand(5); }
10809+
void set_biasVector(llvm::Value *val) { Instr->setOperand(5, val); }
10810+
llvm::Value *get_biasInterpretation() const { return Instr->getOperand(6); }
10811+
void set_biasInterpretation(llvm::Value *val) { Instr->setOperand(6, val); }
1080610812
};
1080710813

1080810814
/// This instruction accumulates a matrix to a RWByteAddressBuffer
@@ -10920,6 +10926,40 @@ struct DxilInst_LinAlgMatrixOuterProduct {
1092010926
void set_vectorB(llvm::Value *val) { Instr->setOperand(2, val); }
1092110927
};
1092210928

10929+
/// This instruction converts vector components from one interpretation to
10930+
/// another
10931+
struct DxilInst_LinAlgConvert {
10932+
llvm::Instruction *Instr;
10933+
// Construction and identification
10934+
DxilInst_LinAlgConvert(llvm::Instruction *pInstr) : Instr(pInstr) {}
10935+
operator bool() const {
10936+
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
10937+
hlsl::OP::OpCode::LinAlgConvert);
10938+
}
10939+
// Validation support
10940+
bool isAllowed() const { return true; }
10941+
bool isArgumentListValid() const {
10942+
if (4 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10943+
return false;
10944+
return true;
10945+
}
10946+
// Metadata
10947+
bool requiresUniformInputs() const { return false; }
10948+
// Operand indexes
10949+
enum OperandIdx {
10950+
arg_inputVector = 1,
10951+
arg_inputInterpretation = 2,
10952+
arg_outputInterpretation = 3,
10953+
};
10954+
// Accessors
10955+
llvm::Value *get_inputVector() const { return Instr->getOperand(1); }
10956+
void set_inputVector(llvm::Value *val) { Instr->setOperand(1, val); }
10957+
llvm::Value *get_inputInterpretation() const { return Instr->getOperand(2); }
10958+
void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(2, val); }
10959+
llvm::Value *get_outputInterpretation() const { return Instr->getOperand(3); }
10960+
void set_outputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); }
10961+
};
10962+
1092310963
/// This instruction triggers a breakpoint if a debugger is attached
1092410964
struct DxilInst_DebugBreak {
1092510965
llvm::Instruction *Instr;

include/dxc/HlslIntrinsicOp.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ enum class IntrinsicOp {
112112
IOP_WorldToObject = 99,
113113
IOP_WorldToObject3x4 = 100,
114114
IOP_WorldToObject4x3 = 101,
115+
IOP___builtin_LinAlg_Convert = 422,
115116
IOP___builtin_LinAlg_CopyConvertMatrix = 401,
116117
IOP___builtin_LinAlg_FillMatrix = 402,
117118
IOP___builtin_LinAlg_MatrixAccumulate = 411,
@@ -428,7 +429,7 @@ enum class IntrinsicOp {
428429
IOP_usign = 355,
429430
MOP_InterlockedUMax = 356,
430431
MOP_InterlockedUMin = 357,
431-
Num_Intrinsics = 422,
432+
Num_Intrinsics = 423,
432433
};
433434
inline bool HasUnsignedIntrinsicOpcode(IntrinsicOp opcode) {
434435
switch (opcode) {

lib/DXIL/DxilOperations.cpp

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,25 +2976,25 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = {
29762976
3,
29772977
{{0x200}, {0x400}, {0x400}},
29782978
{{0x0}, {0x63}, {0x63}}}, // Overloads: o,<hfwi,<hfwi
2979-
2980-
{OC::ReservedE1,
2981-
"ReservedE1",
2982-
OCC::Reserved,
2983-
"reserved",
2979+
{OC::LinAlgConvert,
2980+
"LinAlgConvert",
2981+
OCC::LinAlgConvert,
2982+
"linAlgConvert",
29842983
Attribute::None,
2985-
0,
2986-
{},
2987-
{}}, // Overloads: v
2988-
{OC::ReservedE2,
2989-
"ReservedE2",
2984+
2,
2985+
{{0x400}, {0x400}},
2986+
{{0x63}, {0x63}}}, // Overloads: <hfwi,<hfwi
2987+
2988+
{OC::ReservedE0,
2989+
"ReservedE0",
29902990
OCC::Reserved,
29912991
"reserved",
29922992
Attribute::None,
29932993
0,
29942994
{},
29952995
{}}, // Overloads: v
2996-
{OC::ReservedE3,
2997-
"ReservedE3",
2996+
{OC::ReservedE1,
2997+
"ReservedE1",
29982998
OCC::Reserved,
29992999
"reserved",
30003000
Attribute::None,
@@ -3955,12 +3955,13 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
39553955
// LinAlgMatrixQueryAccumulatorLayout=2147483670, LinAlgMatVecMul=2147483673,
39563956
// LinAlgMatVecMulAdd=2147483674,
39573957
// LinAlgMatrixAccumulateToDescriptor=2147483675,
3958-
// LinAlgMatrixOuterProduct=2147483677, DebugBreak=2147483681,
3959-
// IsDebuggerPresent=2147483682
3958+
// LinAlgMatrixOuterProduct=2147483677, LinAlgConvert=2147483678,
3959+
// DebugBreak=2147483681, IsDebuggerPresent=2147483682
39603960
if (op == 2147483648 || (2147483652 <= op && op <= 2147483653) ||
39613961
(2147483656 <= op && op <= 2147483657) || op == 2147483662 ||
39623962
op == 2147483670 || (2147483673 <= op && op <= 2147483675) ||
3963-
op == 2147483677 || (2147483681 <= op && op <= 2147483682)) {
3963+
(2147483677 <= op && op <= 2147483678) ||
3964+
(2147483681 <= op && op <= 2147483682)) {
39643965
major = 6;
39653966
minor = 10;
39663967
return;
@@ -6636,13 +6637,15 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66366637
A(EXT(0));
66376638
A(pI32);
66386639
A(EXT(1));
6640+
A(pI1);
66396641
A(EXT(2));
66406642
A(pI32);
66416643
break;
66426644
case OpCode::LinAlgMatVecMulAdd:
66436645
A(EXT(0));
66446646
A(pI32);
66456647
A(EXT(1));
6648+
A(pI1);
66466649
A(EXT(2));
66476650
A(pI32);
66486651
A(EXT(3));
@@ -6673,17 +6676,20 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
66736676
A(EXT(1));
66746677
A(EXT(2));
66756678
break;
6676-
6677-
//
6678-
case OpCode::ReservedE1:
6679-
A(pV);
6679+
case OpCode::LinAlgConvert:
6680+
A(EXT(0));
6681+
A(pI32);
6682+
A(EXT(1));
6683+
A(pI32);
66806684
A(pI32);
66816685
break;
6682-
case OpCode::ReservedE2:
6686+
6687+
//
6688+
case OpCode::ReservedE0:
66836689
A(pV);
66846690
A(pI32);
66856691
break;
6686-
case OpCode::ReservedE3:
6692+
case OpCode::ReservedE1:
66876693
A(pV);
66886694
A(pI32);
66896695
break;
@@ -7002,9 +7008,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70027008
case OpCode::GetGroupWaveCount:
70037009
case OpCode::ClusterID:
70047010
case OpCode::LinAlgMatrixQueryAccumulatorLayout:
7011+
case OpCode::ReservedE0:
70057012
case OpCode::ReservedE1:
7006-
case OpCode::ReservedE2:
7007-
case OpCode::ReservedE3:
70087013
case OpCode::DebugBreak:
70097014
case OpCode::IsDebuggerPresent:
70107015
return Type::getVoidTy(Ctx);
@@ -7047,6 +7052,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70477052
case OpCode::LinAlgFillMatrix:
70487053
case OpCode::LinAlgCopyConvertMatrix:
70497054
case OpCode::LinAlgMatrixGetElement:
7055+
case OpCode::LinAlgConvert:
70507056
if (FT->getNumParams() < 2)
70517057
return nullptr;
70527058
return llvm::StructType::get(Ctx,
@@ -7060,6 +7066,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70607066
{FT->getReturnType(), FT->getParamType(1)->getPointerElementType()});
70617067

70627068
case OpCode::LinAlgMatrixSetElement:
7069+
case OpCode::LinAlgMatVecMul:
70637070
if (FT->getNumParams() < 4)
70647071
return nullptr;
70657072
return llvm::StructType::get(
@@ -7075,19 +7082,18 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
70757082

70767083
case OpCode::LinAlgMatrixMultiply:
70777084
case OpCode::LinAlgMatrixAccumulate:
7078-
case OpCode::LinAlgMatVecMul:
70797085
case OpCode::LinAlgMatrixOuterProduct:
70807086
if (FT->getNumParams() < 3)
70817087
return nullptr;
70827088
return llvm::StructType::get(
70837089
Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(2)});
70847090

70857091
case OpCode::LinAlgMatVecMulAdd:
7086-
if (FT->getNumParams() < 5)
7092+
if (FT->getNumParams() < 6)
70877093
return nullptr;
70887094
return llvm::StructType::get(Ctx,
70897095
{FT->getReturnType(), FT->getParamType(1),
7090-
FT->getParamType(2), FT->getParamType(4)});
7096+
FT->getParamType(3), FT->getParamType(5)});
70917097

70927098
// OPCODE-OLOAD-TYPES:END
70937099
default:

0 commit comments

Comments
 (0)