diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index 8c48202ce0..941eab6474 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -10651,14 +10651,14 @@ struct DxilInst_LinAlgMatrixLoadFromMemory { bool requiresUniformInputs() const { return false; } // Operand indexes enum OperandIdx { - arg_groupsharedArr = 1, + arg_memory = 1, arg_offset = 2, arg_stride = 3, arg_layout = 4, }; // Accessors - llvm::Value *get_groupsharedArr() const { return Instr->getOperand(1); } - void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_memory() const { return Instr->getOperand(1); } + void set_memory(llvm::Value *val) { Instr->setOperand(1, val); } llvm::Value *get_offset() const { return Instr->getOperand(2); } void set_offset(llvm::Value *val) { Instr->setOperand(2, val); } llvm::Value *get_stride() const { return Instr->getOperand(3); } @@ -10854,7 +10854,7 @@ struct DxilInst_LinAlgMatrixStoreToMemory { // Operand indexes enum OperandIdx { arg_matrix = 1, - arg_groupsharedArr = 2, + arg_memory = 2, arg_offset = 3, arg_stride = 4, arg_layout = 5, @@ -10862,8 +10862,8 @@ struct DxilInst_LinAlgMatrixStoreToMemory { // Accessors llvm::Value *get_matrix() const { return Instr->getOperand(1); } void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); } - void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_memory() const { return Instr->getOperand(2); } + void set_memory(llvm::Value *val) { Instr->setOperand(2, val); } llvm::Value *get_offset() const { return Instr->getOperand(3); } void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } llvm::Value *get_stride() const { return Instr->getOperand(4); } @@ -11091,7 +11091,7 @@ struct DxilInst_LinAlgMatrixAccumulateToMemory { // Operand indexes enum OperandIdx { arg_matrix = 1, - arg_groupsharedArr = 2, + arg_memory = 2, arg_offset = 3, arg_stride = 4, arg_layout = 5, @@ -11099,8 +11099,8 @@ struct DxilInst_LinAlgMatrixAccumulateToMemory { // Accessors llvm::Value *get_matrix() const { return Instr->getOperand(1); } void set_matrix(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); } - void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_memory() const { return Instr->getOperand(2); } + void set_memory(llvm::Value *val) { Instr->setOperand(2, val); } llvm::Value *get_offset() const { return Instr->getOperand(3); } void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } llvm::Value *get_stride() const { return Instr->getOperand(4); } diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 4138b3d930..ffff4eccd9 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -3166,26 +3166,33 @@ const char *OP::GetOverloadTypeName(unsigned TypeSlot) { StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl &Storage) { DXASSERT(!Ty->isVoidTy(), "must not pass void type here"); unsigned TypeSlot = OP::GetTypeSlot(Ty); + if (TypeSlot < TS_BasicCount) { return GetOverloadTypeName(TypeSlot); - } else if (TypeSlot == TS_UDT) { + } + + switch (TypeSlot) { + case TS_UDT: { if (Ty->isPointerTy()) Ty = Ty->getPointerElementType(); StructType *ST = cast(Ty); return ST->getStructName(); - } else if (TypeSlot == TS_Object) { + } + case TS_Object: { StructType *ST = cast(Ty); if (dxilutil::IsHLSLLinAlgMatrixType(Ty)) return (Twine("m") + Twine(dxilutil::GetHLSLLinAlgMatrixTypeMangling(ST))) .toStringRef(Storage); return ST->getStructName(); - } else if (TypeSlot == TS_Vector) { + } + case TS_Vector: { VectorType *VecTy = cast(Ty); return (Twine("v") + Twine(VecTy->getNumElements()) + Twine( GetOverloadTypeName(OP::GetTypeSlot(VecTy->getElementType())))) .toStringRef(Storage); - } else if (TypeSlot == TS_Extended) { + } + case TS_Extended: { DXASSERT(isa(Ty), "otherwise, extended overload type not wrapped in struct type."); StructType *ST = cast(Ty); @@ -3200,11 +3207,14 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl &Storage) { OS << GetTypeName(ST->getElementType(I), TempStr); } return OS.str(); - } else { - raw_svector_ostream OS(Storage); - Ty->print(OS); - return OS.str(); } + default: + break; + } + + raw_svector_ostream OS(Storage); + Ty->print(OS); + return OS.str(); } StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode, @@ -4314,9 +4324,10 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { #define VEC2(_y) A(VectorType::get(_y, 2)) #define VEC4(_y) A(GetStructVectorType(4, _y)) #define VEC9(_y) A(VectorType::get(_y, 9)) +#define TGSM(_y) A(PointerType::get(_y, DXIL::kTGSMAddrSpace)) // Extended Overload types are wrapped in an anonymous struct -#define EXT(_y) A(cast(pOverloadType)->getElementType(_y)) +#define EXT(_y) cast(pOverloadType)->getElementType(_y) /* hctdb_instrhelp.get_oloads_funcs()*/ switch (opCode) { // return opCode @@ -6427,9 +6438,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { // Linear Algebra Operations case OpCode::MatVecMul: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); A(pI1); A(pI32); A(pRes); @@ -6443,9 +6454,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI1); break; case OpCode::MatVecMulAdd: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); A(pI1); A(pI32); A(pRes); @@ -6464,8 +6475,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { case OpCode::OuterProductAccumulate: A(pV); A(pI32); - EXT(0); - EXT(1); + A(EXT(0)); + A(EXT(1)); A(pRes); A(pI32); A(pI32); @@ -6568,21 +6579,21 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { // Linear Algebra Operations case OpCode::LinAlgMatrixMultiplyAccumulate: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); - EXT(3); + A(EXT(1)); + A(EXT(2)); + A(EXT(3)); break; case OpCode::LinAlgFillMatrix: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); break; case OpCode::LinAlgCopyConvertMatrix: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); A(pI1); break; case OpCode::LinAlgMatrixLoadFromDescriptor: @@ -6594,9 +6605,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI32); break; case OpCode::LinAlgMatrixLoadFromMemory: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + TGSM(EXT(1)); A(pI32); A(pI32); A(pI32); @@ -6613,17 +6624,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI32); break; case OpCode::LinAlgMatrixGetElement: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); A(pI32); break; case OpCode::LinAlgMatrixSetElement: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); + A(EXT(1)); A(pI32); - EXT(2); + A(EXT(2)); break; case OpCode::LinAlgMatrixStoreToDescriptor: A(pV); @@ -6637,8 +6648,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { case OpCode::LinAlgMatrixStoreToMemory: A(pV); A(pI32); - EXT(0); - EXT(1); + A(EXT(0)); + TGSM(EXT(1)); A(pI32); A(pI32); A(pI32); @@ -6648,31 +6659,31 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI32); break; case OpCode::LinAlgMatrixMultiply: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); + A(EXT(1)); + A(EXT(2)); break; case OpCode::LinAlgMatrixAccumulate: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); + A(EXT(1)); + A(EXT(2)); break; case OpCode::LinAlgMatVecMul: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); + A(EXT(1)); + A(EXT(2)); A(pI32); break; case OpCode::LinAlgMatVecMulAdd: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); + A(EXT(1)); + A(EXT(2)); A(pI32); - EXT(3); + A(EXT(3)); A(pI32); break; case OpCode::LinAlgMatrixAccumulateToDescriptor: @@ -6687,17 +6698,17 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { case OpCode::LinAlgMatrixAccumulateToMemory: A(pV); A(pI32); - EXT(0); - EXT(1); + A(EXT(0)); + TGSM(EXT(1)); A(pI32); A(pI32); A(pI32); break; case OpCode::LinAlgMatrixOuterProduct: - EXT(0); + A(EXT(0)); A(pI32); - EXT(1); - EXT(2); + A(EXT(1)); + A(EXT(2)); break; // @@ -7064,7 +7075,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::MatVecMulAdd: case OpCode::LinAlgFillMatrix: case OpCode::LinAlgCopyConvertMatrix: - case OpCode::LinAlgMatrixLoadFromMemory: case OpCode::LinAlgMatrixGetElement: if (FT->getNumParams() < 2) return nullptr; @@ -7072,8 +7082,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { {FT->getReturnType(), FT->getParamType(1)}); case OpCode::OuterProductAccumulate: - case OpCode::LinAlgMatrixStoreToMemory: - case OpCode::LinAlgMatrixAccumulateToMemory: if (FT->getNumParams() < 3) return nullptr; return llvm::StructType::get(Ctx, @@ -7086,12 +7094,27 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { {FT->getReturnType(), FT->getParamType(1), FT->getParamType(2), FT->getParamType(3)}); + case OpCode::LinAlgMatrixLoadFromMemory: + if (FT->getNumParams() < 2) + return nullptr; + return llvm::StructType::get( + Ctx, + {FT->getReturnType(), FT->getParamType(1)->getPointerElementType()}); + case OpCode::LinAlgMatrixSetElement: if (FT->getNumParams() < 4) return nullptr; return llvm::StructType::get( Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)}); + case OpCode::LinAlgMatrixStoreToMemory: + case OpCode::LinAlgMatrixAccumulateToMemory: + if (FT->getNumParams() < 3) + return nullptr; + return llvm::StructType::get( + Ctx, + {FT->getParamType(1), FT->getParamType(2)->getPointerElementType()}); + case OpCode::LinAlgMatrixMultiply: case OpCode::LinAlgMatrixAccumulate: case OpCode::LinAlgMatVecMul: diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 4f22a4598d..6d718257d4 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -7226,6 +7226,60 @@ Value *TranslateLinAlgCopyConvertMatrix(CallInst *CI, IntrinsicOp IOP, return nullptr; } +Value *TranslateLinAlgMatrixLoadFromMemory( + CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode, + HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Value *MatrixPtr = CI->getArgOperand(1); + DXASSERT_NOMSG(isa(MatrixPtr->getType())); + Type *MatrixType = MatrixPtr->getType()->getPointerElementType(); + + Value *Arr = CI->getArgOperand(2); + Value *Offset = CI->getArgOperand(3); + Value *Stride = CI->getArgOperand(4); + Value *Layout = CI->getArgOperand(5); + + Value *Zero = Builder.getInt32(0); + Value *ArrPtr = Builder.CreateGEP(Arr, {Zero, Zero}); + Type *ArrEltTy = ArrPtr->getType()->getPointerElementType(); + + Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode); + Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {MatrixType, ArrEltTy}); + + Value *Matrix = + Builder.CreateCall(DxilFunc, {OpArg, ArrPtr, Offset, Stride, Layout}); + Builder.CreateStore(Matrix, MatrixPtr); + + return nullptr; +} + +Value *TranslateLinAlgMatrixAccumStoreToMemory( + CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode, + HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper, + bool &Translated) { + hlsl::OP *HlslOp = &Helper.hlslOP; + IRBuilder<> Builder(CI); + + Value *Matrix = CI->getArgOperand(1); + Value *Arr = CI->getArgOperand(2); + Value *Offset = CI->getArgOperand(3); + Value *Stride = CI->getArgOperand(4); + Value *Layout = CI->getArgOperand(5); + + Value *Zero = Builder.getInt32(0); + Value *ArrPtr = Builder.CreateGEP(Arr, {Zero, Zero}); + Type *ArrEltTy = ArrPtr->getType()->getPointerElementType(); + + Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode); + Function *DxilFunc = HlslOp->GetOpFunc(OpCode, {Matrix->getType(), ArrEltTy}); + + return Builder.CreateCall(DxilFunc, + {OpArg, Matrix, ArrPtr, Offset, Stride, Layout}); +} + } // namespace // Lower table. @@ -7989,14 +8043,16 @@ constexpr IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromDescriptor, TranslateLinAlgMatrixLoadFromDescriptor, DXIL::OpCode::LinAlgMatrixLoadFromDescriptor}, - {IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromMemory, EmptyLower, + {IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromMemory, + TranslateLinAlgMatrixLoadFromMemory, DXIL::OpCode::LinAlgMatrixLoadFromMemory}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixSetElement, TranslateLinAlgMatrixSetElement, DXIL::OpCode::LinAlgMatrixSetElement}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToDescriptor, TranslateLinAlgMatrixAccumStoreToDescriptor, DXIL::OpCode::LinAlgMatrixStoreToDescriptor}, - {IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToMemory, EmptyLower, + {IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToMemory, + TranslateLinAlgMatrixAccumStoreToMemory, DXIL::OpCode::LinAlgMatrixStoreToMemory}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate, TranslateLinAlgMatrixAccumulate, DXIL::OpCode::LinAlgMatrixAccumulate}, @@ -8010,7 +8066,8 @@ constexpr IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor, TranslateLinAlgMatrixAccumStoreToDescriptor, DXIL::OpCode::LinAlgMatrixAccumulateToDescriptor}, - {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToMemory, EmptyLower, + {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToMemory, + TranslateLinAlgMatrixAccumStoreToMemory, DXIL::OpCode::LinAlgMatrixAccumulateToMemory}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixOuterProduct, TranslateLinAlgMatrixOuterProduct, DXIL::OpCode::LinAlgMatrixOuterProduct}, diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 6c95fdfa1e..002ec16a85 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -1957,6 +1957,9 @@ ParamModsFromIntrinsicArg(const HLSL_INTRINSIC_ARGUMENT *pArg) { } if (pArg->qwUsage == AR_QUAL_REF) return hlsl::ParameterModifier(hlsl::ParameterModifier::Kind::Ref); + // TODO: https://github.com/microsoft/DirectXShaderCompiler/issues/8270 + if (pArg->qwUsage == AR_QUAL_GROUPSHARED) + return hlsl::ParameterModifier(hlsl::ParameterModifier::Kind::In); DXASSERT(qwUsage & AR_QUAL_IN, "else usage is incorrect"); return hlsl::ParameterModifier(hlsl::ParameterModifier::Kind::In); } diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixaccumulatetomemory/nominal.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixaccumulatetomemory/nominal.hlsl new file mode 100644 index 0000000000..f05366d62f --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixaccumulatetomemory/nominal.hlsl @@ -0,0 +1,14 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s + +// CHECK: @{{.*}} = external addrspace(3) global [64 x float] +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + // CHECK-LABEL: define void @main() + + // CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC4M5N4U1S2.f32(i32 -2147483620, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout) + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixAccumulateToMemory(mat, SharedArr, 1, 2, 3); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixloadfrommemory/nominal.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixloadfrommemory/nominal.hlsl new file mode 100644 index 0000000000..9c1e8303b2 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixloadfrommemory/nominal.hlsl @@ -0,0 +1,14 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s + +// CHECK: @{{.*}} = external addrspace(3) global [64 x float] +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + // CHECK-LABEL: define void @main() + + // CHECK: call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgMatrixLoadFromMemory.mC4M5N4U1S2.f32(i32 -2147483633, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixLoadFromMemory(memory,offset,stride,layout) + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixLoadFromMemory(mat, SharedArr, 1, 2, 3); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixstoretomemory/nominal.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixstoretomemory/nominal.hlsl new file mode 100644 index 0000000000..07a4fa38e5 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixstoretomemory/nominal.hlsl @@ -0,0 +1,14 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T cs_6_10 -HV 202x -E main %s | FileCheck %s + +// CHECK: @{{.*}} = external addrspace(3) global [64 x float] +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + // CHECK-LABEL: define void @main() + + // CHECK: call void @dx.op.linAlgMatrixStoreToMemory.mC4M5N4U1S2.f32(i32 -2147483627, %dx.types.LinAlgMatrixC4M5N4U1S2 {{.*}}, float addrspace(3)* getelementptr {{.*}}, i32 1, i32 2, i32 3) ; LinAlgMatrixStoreToMemory(matrix,memory,offset,stride,layout) + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixStoreToMemory(mat, SharedArr, 1, 2, 3); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/ast.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/ast.hlsl new file mode 100644 index 0000000000..d300796b67 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/ast.hlsl @@ -0,0 +1,20 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s + +// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixAccumulateToMemory 'void (__builtin_LinAlgMatrix {{.*}}, float const __attribute__((address_space(3))) (&)[64], unsigned int, unsigned int, unsigned int)' extern +// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}' +// CHECK-NEXT: ParmVarDecl {{.*}} memory 'float const __attribute__((address_space(3))) (&)[64]' +// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int' +// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 420 +// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 "" + +groupshared float SharedArr[64]; + +[shader("compute")] +[numthreads(1,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixAccumulateToMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/unavailable_pre_sm610.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/unavailable_pre_sm610.hlsl new file mode 100644 index 0000000000..e5a9ea4895 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixaccumulatetomemory/unavailable_pre_sm610.hlsl @@ -0,0 +1,11 @@ +// RUN: %dxc -T cs_6_9 -HV 202x -E main %s -verify + +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + + // expected-error@+1{{intrinsic __builtin_LinAlg_MatrixAccumulateToMemory potentially used by ''main'' requires shader model 6.10 or greater}} + __builtin_LinAlg_MatrixAccumulateToMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/ast.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/ast.hlsl new file mode 100644 index 0000000000..3ac0de3880 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/ast.hlsl @@ -0,0 +1,20 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s + +// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixLoadFromMemory 'void (__builtin_LinAlgMatrix {{.*}}, float const __attribute__((address_space(3))) (&)[64], unsigned int, unsigned int, unsigned int)' extern +// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix {{.*}}' +// CHECK-NEXT: ParmVarDecl {{.*}} memory 'float const __attribute__((address_space(3))) (&)[64]' +// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int' +// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 411 +// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 "" + +groupshared float SharedArr[64]; + +[shader("compute")] +[numthreads(1,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixLoadFromMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/unavailable_pre_sm610.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/unavailable_pre_sm610.hlsl new file mode 100644 index 0000000000..d8472ad92b --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixloadfrommemory/unavailable_pre_sm610.hlsl @@ -0,0 +1,11 @@ +// RUN: %dxc -T cs_6_9 -HV 202x -E main %s -verify + +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + + // expected-error@+1{{intrinsic __builtin_LinAlg_MatrixLoadFromMemory potentially used by ''main'' requires shader model 6.10 or greater}} + __builtin_LinAlg_MatrixLoadFromMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/ast.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/ast.hlsl new file mode 100644 index 0000000000..c726d119eb --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/ast.hlsl @@ -0,0 +1,20 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s + +// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_MatrixStoreToMemory 'void (__builtin_LinAlgMatrix {{.*}}, float const __attribute__((address_space(3))) (&)[64], unsigned int, unsigned int, unsigned int)' extern +// CHECK-NEXT: ParmVarDecl {{.*}} matrix '__builtin_LinAlgMatrix {{.*}}' +// CHECK-NEXT: ParmVarDecl {{.*}} memory 'float const __attribute__((address_space(3))) (&)[64]' +// CHECK-NEXT: ParmVarDecl {{.*}} offset 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} stride 'unsigned int' +// CHECK-NEXT: ParmVarDecl {{.*}} layout 'unsigned int' +// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 414 +// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 "" + +groupshared float SharedArr[64]; + +[shader("compute")] +[numthreads(1,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + __builtin_LinAlg_MatrixStoreToMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/unavailable_pre_sm610.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/unavailable_pre_sm610.hlsl new file mode 100644 index 0000000000..d3468a2a02 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/matrixstoretomemory/unavailable_pre_sm610.hlsl @@ -0,0 +1,11 @@ +// RUN: %dxc -T cs_6_9 -HV 202x -E main %s -verify + +groupshared float SharedArr[64]; + +[numthreads(4,1,1)] +void main() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat; + + // expected-error@+1{{intrinsic __builtin_LinAlg_MatrixStoreToMemory potentially used by ''main'' requires shader model 6.10 or greater}} + __builtin_LinAlg_MatrixStoreToMemory(mat, SharedArr, 0, 0, 0); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/stage-errors.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/stage-errors.hlsl index fbec113e81..c9ebd7adf8 100644 --- a/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/stage-errors.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/builtins/stage-errors.hlsl @@ -8,8 +8,12 @@ // RUN: %dxc -T lib_6_10 -DMATRIX_STORE_TO_DESCRIPTOR %s -verify // RUN: %dxc -T lib_6_10 -DMATRIX_LENGTH %s -verify // RUN: %dxc -T lib_6_10 -DMATRIX_ACCUMULATE %s -verify +// RUN: %dxc -T lib_6_10 -DMATRIX_LOAD_FROM_MEMORY %s -verify +// RUN: %dxc -T lib_6_10 -DMATRIX_STORE_TO_MEMORY %s -verify +// RUN: %dxc -T lib_6_10 -DMATRIX_ACCUMULATE_TO_MEMORY %s -verify RWByteAddressBuffer buf; +groupshared float gs_arr[64]; void CallFunction() { @@ -62,6 +66,18 @@ void CallFunction() #define DO_FUNC __builtin_LinAlg_MatrixAccumulate(mat1, mat2, mat3); #endif +#ifdef MATRIX_LOAD_FROM_MEMORY + #define DO_FUNC __builtin_LinAlg_MatrixLoadFromMemory(mat1, gs_arr, 0, 0, 0); +#endif + +#ifdef MATRIX_STORE_TO_MEMORY + #define DO_FUNC __builtin_LinAlg_MatrixStoreToMemory(mat1, gs_arr, 0, 0, 0); +#endif + +#ifdef MATRIX_ACCUMULATE_TO_MEMORY + #define DO_FUNC __builtin_LinAlg_MatrixAccumulateToMemory(mat1, gs_arr, 0, 0, 0); +#endif + // The builtins below are allowed in all stages, if they raise an error // then the test will fail with "saw unexpected diagnostic" uint layout = __builtin_LinAlg_MatrixQueryAccumulatorLayout(); diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index 49aa2f151b..6395f26a58 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -397,18 +397,17 @@ void [[min_sm=6.10]] __builtin_VectorAccumulate(in LinAlg InputVector, in RWB // LinAlg intrinsics -// TODO: Replace all int GroupSharedMem with groupshared memory void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeric value); void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose); void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix); uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex); void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out numeric ret, in LinAlgMatrix matrix, in uint threadLocalIndex); void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value); void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); @@ -416,7 +415,7 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB); } namespace diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 71f035e059..e88834aa62 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -51,7 +51,8 @@ # processing. # - "," is used to separate multiple overload dimensions. # - When used, only $x0, $x1, etc. are supported for overloaded parameter -# types. +# types. $x_gs0, $x_gs1, etc work like $xN except the overload will be a +# pointer to groupshared memory. # dxil_all_user_oload_chars must be kept in sync with the indices in # hlsl::OP::TypeSlot in DxilOperations.h. dxil_all_user_oload_chars = "hfd18wiluo<" @@ -295,8 +296,12 @@ def check_extended_oload_ops(self): return next_oload_idx = 0 for i in self.ops: - if i.llvm_type.startswith("$x"): - if i.llvm_type != "$x" + str(next_oload_idx): + # _gs is extra metadata info on the overload. It has no impact on + # the ordering rules so it can be erased for the check. + # $x_gs7 -> $x7 + ty = i.llvm_type.replace("$x_gs", "$x") + if ty.startswith("$x"): + if ty != "$x" + str(next_oload_idx): raise ValueError( "Extended overloads are not sequentially referenced in " f"DXIL op {self.name}: {i.llvm_type} != $x{next_oload_idx}" @@ -6406,13 +6411,12 @@ def populate_ExperimentalOps(self): "LinAlgMatrixLoadFromMemory", "LinAlgMatrixLoadFromMemory", "fills a matrix with data from a groupshared array", - "o,hfwi", # TODO: needs to be updated for groupshared + "o,hfwi", "", [ db_dxil_param(0, "$x0", "", "resulting matrix"), - # TODO: [Ty] * addrspace(4), ; groupshared T[M * N] db_dxil_param( - 2, "$x1", "groupsharedArr", "groupshared array to fill matrix with" + 2, "$x_gs1", "memory", "groupshared array to fill matrix with" ), db_dxil_param(3, "i32", "offset", "starting offset in the array"), db_dxil_param( @@ -6508,14 +6512,13 @@ def populate_ExperimentalOps(self): "LinAlgMatrixStoreToMemory", "LinAlgMatrixStoreToMemory", "stores a matrix to groupshared memory", - "o,hfwi", # TODO: needs to be updated for groupshared + "o,hfwi", "", [ db_dxil_param(0, "v", "", ""), db_dxil_param(2, "$x0", "matrix", "matrix to be stored"), - # TODO: [Ty] * addrspace(4), ; groupshared T[M * N] db_dxil_param( - 3, "$x1", "groupsharedArr", "groupshared array to store into" + 3, "$x_gs1", "memory", "groupshared array to store into" ), db_dxil_param(4, "i32", "offset", "starting offset in the array"), db_dxil_param( @@ -6626,14 +6629,13 @@ def populate_ExperimentalOps(self): "LinAlgMatrixAccumulateToMemory", "LinAlgMatrixAccumulateToMemory", "accumulates a matrix to groupshared memory", - "o,hfwi", # TODO: needs to be updated for groupshared + "o,hfwi", "", [ db_dxil_param(0, "v", "", ""), db_dxil_param(2, "$x0", "matrix", "Accumulator matrix"), - # TODO: [Ty] * addrspace(4), ; groupshared T[M * N] db_dxil_param( - 3, "$x1", "groupsharedArr", "groupshared array to accumulate into" + 3, "$x_gs1", "memory", "groupshared array to accumulate into" ), db_dxil_param(4, "i32", "offset", "starting offset in the array"), db_dxil_param( diff --git a/utils/hct/hctdb_instrhelp.py b/utils/hct/hctdb_instrhelp.py index 5e09578af7..f91012ff75 100644 --- a/utils/hct/hctdb_instrhelp.py +++ b/utils/hct/hctdb_instrhelp.py @@ -644,10 +644,15 @@ def print_opfunc_table(self): "noderecordproperty": "A(nodeRecordProperty);", "hit_object": "A(pHit);", # Extended overload slots, extend as needed: - "$x0": "EXT(0);", - "$x1": "EXT(1);", - "$x2": "EXT(2);", - "$x3": "EXT(3);", + "$x0": "A(EXT(0));", + "$x1": "A(EXT(1));", + "$x2": "A(EXT(2));", + "$x3": "A(EXT(3));", + # Groupshared pointers to extended overloads: + "$x_gs0": "TGSM(EXT(0));", + "$x_gs1": "TGSM(EXT(1));", + "$x_gs2": "TGSM(EXT(2));", + "$x_gs3": "TGSM(EXT(3));", } last_category = None for i in self.db.get_dxil_ops(): @@ -679,6 +684,7 @@ def print_opfunc_oload_type(self): vec_ty = "$vec" gsptr_ty = "$gsptr" extended_ty = "$x" + extended_gs_ty = "$x_gs" last_category = None index_dict = collections.OrderedDict() @@ -846,7 +852,7 @@ def print_opfunc_oload_type(self): # indices the key, and add the opcode to a list of opcodes for that # key. Indices start with 0 for return type, and 1 for the first # function parameter, which is the DXIL OpCode. - indices = [] + indices = [] # (op.pos, unwrap_pointer) pairs for index, op in enumerate(instr.ops): # Skip dxil opcode. if op.pos == 1: @@ -854,8 +860,10 @@ def print_opfunc_oload_type(self): op_type = op.llvm_type if op_type.startswith(extended_ty): + gs_ptr = op_type.startswith(extended_gs_ty) + prefix_len = len(extended_gs_ty) if gs_ptr else len(extended_ty) try: - extended_index = int(op_type[2:]) + extended_index = int(op_type[prefix_len:]) except: raise ValueError( "Error parsing extended operand type " @@ -866,7 +874,7 @@ def print_opfunc_oload_type(self): f"'$x{extended_index}' is not in sequential " + f"order for DXIL op '{instr.name}'" ) - indices.append(op.pos) + indices.append((op.pos, gs_ptr)) if len(indices) != instr.num_oloads: raise ValueError( @@ -875,24 +883,29 @@ def print_opfunc_oload_type(self): ) extended_dict.setdefault(tuple(indices), []).append(instr.name) - def get_type_at_index(index): - if index == 0: - return "FT->getReturnType()" - return f"FT->getParamType({index - 1})" + def get_type_at_index(index, unwrap_pointer): + result = "FT->getReturnType()" + if index > 0: + result = f"FT->getParamType({index - 1})" + if unwrap_pointer: + result = result + "->getPointerElementType()" + return result for index_tuple, opcodes in extended_dict.items(): line = "" for opcode in opcodes: line = line + f"case OpCode::{opcode}:\n" - if index_tuple[-1] > 0: + if index_tuple[-1][0] > 0: line += ( - f" if (FT->getNumParams() < {index_tuple[-1]})\n" + f" if (FT->getNumParams() < {index_tuple[-1][0]})\n" + " return nullptr;\n" ) line += ( " return llvm::StructType::get(Ctx, {" - + ", ".join([get_type_at_index(index) for index in index_tuple]) - + "});\n" + + ", ".join([ + get_type_at_index(index, unwrap_pointer) + for index, unwrap_pointer in index_tuple + ]) + "});\n" ) print(line)