Skip to content

Commit 37338c2

Browse files
authored
[SM6.10] Add LinAlg Matrix attributes and two new attributed matrix types (#8125)
Adds new type attribute `__LinAlg_Matrix_Attributes` that can be attached to `__builtin_LinAlg_Matrix` built-in matrix handle type to specify the matrix component type, dimensions, use and scope. Includes enabling C11++ -style attributes on types in DXC. Adds two new AST types that are used to capture the information from the attribute: - `AttributedLinAlgMatrixType` type is used when the attribute specifies concrete matrix parameters - `DependentAttributedLinAlgMatrixType` is for cases where the type attributes arguments are dependent on template instantiation The `AttributedLinAlgMatrixType` is linked to `LinAlgMatrix` type alias used in `gen_intrin_main.txt` for built-in function APIs. This change also adds the outline of the LinAlg `Matrix` template class to `dx/linalg.h`, including the enumeration types used as arguments to the new attribute and as template parameters of `Matrix`. Note that the `ComponentType` enum in `dx/linalg.h` is defined using values matching the existing internal `ComponentType` enum that was added in SM 6.9 for cooperative vectors. This is different from the current [LinAlg spec](https://github.com/microsoft/hlsl-specs/blob/main/proposals/0035-linalg-matrix.md). The issue to update the spec is microsoft/hlsl-specs#779. Fixes #8122
1 parent 3e6e148 commit 37338c2

34 files changed

Lines changed: 1275 additions & 69 deletions

include/dxc/DXIL/DxilConstants.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,18 @@ enum class ComponentType : uint32_t {
192192
LastEntry
193193
};
194194

195+
enum class MatrixUse : uint32_t {
196+
A = 0,
197+
B = 1,
198+
Accumulator = 2,
199+
};
200+
201+
enum class MatrixScope : uint32_t {
202+
Thread = 0,
203+
Wave = 1,
204+
ThreadGroup = 2,
205+
};
206+
195207
// Must match D3D_INTERPOLATION_MODE
196208
enum class InterpolationMode : uint8_t {
197209
Undefined = 0,

include/dxc/dxcapi.internal.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,19 @@ enum LEGAL_INTRINSIC_COMPTYPES {
130130
LICOMPTYPE_HIT_OBJECT = 51,
131131
LICOMPTYPE_RAY_QUERY = 52,
132132

133-
LICOMPTYPE_LINALG = 53, // f32, partial-precision-f32, f16,
133+
LICOMPTYPE_LINALG_MATRIX = 53,
134+
135+
LICOMPTYPE_LINALG = 54, // f32, partial-precision-f32, f16,
134136
// i32, i16, u32, u16,
135137
// int8_4packed, uint8_4packed
136138

137-
LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 54,
139+
LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 55,
138140

139141
#ifdef ENABLE_SPIRV_CODEGEN
140-
LICOMPTYPE_VK_BUFFER_POINTER = 55,
141-
LICOMPTYPE_COUNT = 56
142+
LICOMPTYPE_VK_BUFFER_POINTER = 56,
143+
LICOMPTYPE_COUNT = 57
142144
#else
143-
LICOMPTYPE_COUNT = 55
145+
LICOMPTYPE_COUNT = 56
144146
#endif
145147
};
146148

tools/clang/include/clang/AST/ASTContext.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef LLVM_CLANG_AST_ASTCONTEXT_H
1616
#define LLVM_CLANG_AST_ASTCONTEXT_H
1717

18+
#include "dxc/DXIL/DxilConstants.h"
1819
#include "clang/AST/ASTTypeTraits.h"
1920
#include "clang/AST/CanonicalType.h"
2021
#include "clang/AST/CommentCommandTraits.h"
@@ -130,6 +131,12 @@ class ASTContext : public RefCountedBase<ASTContext> {
130131
mutable llvm::FoldingSet<AtomicType> AtomicTypes;
131132
llvm::FoldingSet<AttributedType> AttributedTypes;
132133

134+
// HLSL Change Start
135+
llvm::FoldingSet<AttributedLinAlgMatrixType> AttrLinAlgMatrixTypes;
136+
llvm::FoldingSet<DependentAttributedLinAlgMatrixType>
137+
DepAttrLinAlgMatrixTypes;
138+
// HLSL Change End
139+
133140
mutable llvm::FoldingSet<QualifiedTemplateName> QualifiedTemplateNames;
134141
mutable llvm::FoldingSet<DependentTemplateName> DependentTemplateNames;
135142
mutable llvm::FoldingSet<SubstTemplateTemplateParmStorage>
@@ -1156,6 +1163,19 @@ class ASTContext : public RefCountedBase<ASTContext> {
11561163
QualType modifiedType,
11571164
QualType equivalentType);
11581165

1166+
// HLSL Change Start
1167+
QualType getAttributedLinAlgMatrixType(QualType WrappedTy,
1168+
hlsl::DXIL::ComponentType ComponentTy,
1169+
size_t Rows, size_t Cols,
1170+
hlsl::DXIL::MatrixUse Use,
1171+
hlsl::DXIL::MatrixScope Scope);
1172+
1173+
QualType getDependentAttributedLinAlgMatrixType(QualType WrappedTy,
1174+
Expr *ComponentTyExpr,
1175+
Expr *RowsExpr,
1176+
Expr *ColsExpr, Expr *UseExpr,
1177+
Expr *ScopeExpr);
1178+
// HLSL Change End
11591179
QualType getSubstTemplateTypeParmType(const TemplateTypeParmType *Replaced,
11601180
QualType Replacement) const;
11611181
QualType getSubstTemplateTypeParmPackType(

tools/clang/include/clang/AST/DataRecursiveASTVisitor.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,21 @@ DEF_TRAVERSE_TYPE(AutoType, { TRY_TO(TraverseType(T->getDeducedType())); })
900900

901901
DEF_TRAVERSE_TYPE(RecordType, {})
902902
DEF_TRAVERSE_TYPE(EnumType, {})
903+
904+
// HLSL Change Start
905+
DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType,
906+
{ TRY_TO(TraverseType(T->getWrappedType())); })
907+
908+
DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, {
909+
TRY_TO(TraverseType(T->getWrappedType()));
910+
TRY_TO(TraverseStmt(T->getComponentTyExpr()));
911+
TRY_TO(TraverseStmt(T->getRowsExpr()));
912+
TRY_TO(TraverseStmt(T->getColsExpr()));
913+
TRY_TO(TraverseStmt(T->getUseExpr()));
914+
TRY_TO(TraverseStmt(T->getScopeExpr()));
915+
})
916+
// HLSL Change End
917+
903918
DEF_TRAVERSE_TYPE(TemplateTypeParmType, {})
904919
DEF_TRAVERSE_TYPE(SubstTemplateTypeParmType, {})
905920
DEF_TRAVERSE_TYPE(SubstTemplateTypeParmPackType, {})
@@ -1119,6 +1134,19 @@ DEF_TRAVERSE_TYPELOC(AutoType, {
11191134

11201135
DEF_TRAVERSE_TYPELOC(RecordType, {})
11211136
DEF_TRAVERSE_TYPELOC(EnumType, {})
1137+
// HLSL Change Start
1138+
DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {
1139+
TRY_TO(TraverseType(TL.getTypePtr()->getWrappedType()));
1140+
})
1141+
DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {
1142+
TRY_TO(TraverseType(TL.getTypePtr()->getWrappedType()));
1143+
TRY_TO(TraverseStmt(TL.getTypePtr()->getComponentTyExpr()));
1144+
TRY_TO(TraverseStmt(TL.getTypePtr()->getRowsExpr()));
1145+
TRY_TO(TraverseStmt(TL.getTypePtr()->getColsExpr()));
1146+
TRY_TO(TraverseStmt(TL.getTypePtr()->getUseExpr()));
1147+
TRY_TO(TraverseStmt(TL.getTypePtr()->getScopeExpr()));
1148+
})
1149+
// HLSL Change End
11221150
DEF_TRAVERSE_TYPELOC(TemplateTypeParmType, {})
11231151
DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmType, {})
11241152
DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmPackType, {})

tools/clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,20 @@ DEF_TRAVERSE_TYPE(InjectedClassNameType, {})
982982
DEF_TRAVERSE_TYPE(AttributedType,
983983
{ TRY_TO(TraverseType(T->getModifiedType())); })
984984

985+
// HLSL Change Start
986+
DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType,
987+
{ TRY_TO(TraverseType(T->getWrappedType())); })
988+
989+
DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, {
990+
TRY_TO(TraverseType(T->getWrappedType()));
991+
TRY_TO(TraverseStmt(T->getComponentTyExpr()));
992+
TRY_TO(TraverseStmt(T->getRowsExpr()));
993+
TRY_TO(TraverseStmt(T->getColsExpr()));
994+
TRY_TO(TraverseStmt(T->getUseExpr()));
995+
TRY_TO(TraverseStmt(T->getScopeExpr()));
996+
})
997+
// HLSL Change End
998+
985999
DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })
9861000

9871001
DEF_TRAVERSE_TYPE(ElaboratedType, {
@@ -1206,6 +1220,20 @@ DEF_TRAVERSE_TYPELOC(ParenType, { TRY_TO(TraverseTypeLoc(TL.getInnerLoc())); })
12061220
DEF_TRAVERSE_TYPELOC(AttributedType,
12071221
{ TRY_TO(TraverseTypeLoc(TL.getModifiedLoc())); })
12081222

1223+
// HLSL Change Start
1224+
DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {
1225+
TRY_TO(TraverseType(TL.getTypePtr()->getWrappedType()));
1226+
})
1227+
DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {
1228+
TRY_TO(TraverseType(TL.getTypePtr()->getWrappedType()));
1229+
TRY_TO(TraverseStmt(TL.getTypePtr()->getComponentTyExpr()));
1230+
TRY_TO(TraverseStmt(TL.getTypePtr()->getRowsExpr()));
1231+
TRY_TO(TraverseStmt(TL.getTypePtr()->getColsExpr()));
1232+
TRY_TO(TraverseStmt(TL.getTypePtr()->getUseExpr()));
1233+
TRY_TO(TraverseStmt(TL.getTypePtr()->getScopeExpr()));
1234+
})
1235+
// HLSL Change End
1236+
12091237
DEF_TRAVERSE_TYPELOC(ElaboratedType, {
12101238
if (TL.getQualifierLoc()) {
12111239
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));

tools/clang/include/clang/AST/Type.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_CLANG_AST_TYPE_H
1515
#define LLVM_CLANG_AST_TYPE_H
1616

17+
#include "dxc/DXIL/DxilConstants.h"
1718
#include "clang/AST/NestedNameSpecifier.h"
1819
#include "clang/AST/TemplateName.h"
1920
#include "clang/Basic/AddressSpaces.h"
@@ -1699,7 +1700,14 @@ class Type : public ExtQualsTypeCommonBase {
16991700

17001701
bool isOpenCLSpecificType() const; // Any OpenCL specific type
17011702

1703+
// HLSL Change Start
17021704
bool isLinAlgMatrixType() const; // HLSL __builtin_LinAlgMatrix
1705+
bool isAttributedLinAlgMatrixType()
1706+
const; // HLSL attributed __builtin_LinAlgMatrix
1707+
bool isDependentAttributedLinAlgMatrixType()
1708+
const; // HLSL attributed __builtin_LinAlgMatrix with dependent
1709+
// parameters
1710+
// HLSL Change End
17031711

17041712
/// Determines if this type, which must satisfy
17051713
/// isObjCLifetimeType(), is implicitly __unsafe_unretained rather
@@ -3736,6 +3744,108 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
37363744
}
37373745
};
37383746

3747+
// HLSL Change Start
3748+
3749+
class AttributedLinAlgMatrixType : public Type, public llvm::FoldingSetNode {
3750+
friend class ASTContext; // ASTContext creates these
3751+
3752+
QualType WrappedType; // should be __builtin_LinAlgMatrix
3753+
hlsl::DXIL::ComponentType ComponentTy;
3754+
size_t Rows, Cols;
3755+
hlsl::DXIL::MatrixUse Use;
3756+
hlsl::DXIL::MatrixScope Scope;
3757+
3758+
AttributedLinAlgMatrixType(QualType WrappedTy,
3759+
hlsl::DXIL::ComponentType ComponentTy, size_t Rows,
3760+
size_t Cols, hlsl::DXIL::MatrixUse Use,
3761+
hlsl::DXIL::MatrixScope Scope)
3762+
: Type(AttributedLinAlgMatrix, QualType(), /*Dependent*/ false,
3763+
/*InstantiationDependent*/ false, /*VariablyModified*/ false,
3764+
/*ContainsUnexpandedParameterPack*/ false),
3765+
WrappedType(WrappedTy), ComponentTy(ComponentTy), Rows(Rows),
3766+
Cols(Cols), Use(Use), Scope(Scope) {}
3767+
3768+
public:
3769+
QualType getWrappedType() const { return WrappedType; }
3770+
3771+
hlsl::DXIL::ComponentType getComponentType() const { return ComponentTy; }
3772+
size_t getRows() const { return Rows; }
3773+
size_t getCols() const { return Cols; }
3774+
hlsl::DXIL::MatrixUse getUse() const { return Use; }
3775+
hlsl::DXIL::MatrixScope getScope() const { return Scope; }
3776+
3777+
void appendMangledAttributes(llvm::raw_ostream &OS) const;
3778+
3779+
bool isSugared() const { return false; }
3780+
QualType desugar() const { return QualType(this, 0); }
3781+
3782+
void Profile(llvm::FoldingSetNodeID &ID) {
3783+
Profile(ID, WrappedType, ComponentTy, Rows, Cols, Use, Scope);
3784+
}
3785+
3786+
static void Profile(llvm::FoldingSetNodeID &ID, QualType WrappedTy,
3787+
hlsl::DXIL::ComponentType ComponentTy, size_t Rows,
3788+
size_t Cols, hlsl::DXIL::MatrixUse Use,
3789+
hlsl::DXIL::MatrixScope Scope) {
3790+
ID.AddPointer(WrappedTy.getAsOpaquePtr());
3791+
ID.AddInteger(static_cast<uint32_t>(ComponentTy));
3792+
ID.AddInteger(static_cast<uint32_t>(Rows));
3793+
ID.AddInteger(static_cast<uint32_t>(Cols));
3794+
ID.AddInteger(static_cast<uint32_t>(Use));
3795+
ID.AddInteger(static_cast<uint32_t>(Scope));
3796+
}
3797+
3798+
static bool classof(const Type *T) {
3799+
return T->getTypeClass() == AttributedLinAlgMatrix;
3800+
}
3801+
};
3802+
3803+
class DependentAttributedLinAlgMatrixType : public Type,
3804+
public llvm::FoldingSetNode {
3805+
const ASTContext &Context;
3806+
QualType WrappedType; // should be __builtin_LinAlgMatrix
3807+
Expr *ComponentTyExpr;
3808+
Expr *RowsExpr;
3809+
Expr *ColsExpr;
3810+
Expr *UseExpr;
3811+
Expr *ScopeExpr;
3812+
3813+
DependentAttributedLinAlgMatrixType(const ASTContext &Context,
3814+
QualType WrappedType,
3815+
Expr *ComponentTyExpr, Expr *RowsExpr,
3816+
Expr *ColsExpr, Expr *UseExpr,
3817+
Expr *ScopeExpr);
3818+
3819+
friend class ASTContext;
3820+
3821+
public:
3822+
QualType getWrappedType() const { return WrappedType; }
3823+
Expr *getComponentTyExpr() const { return ComponentTyExpr; }
3824+
Expr *getRowsExpr() const { return RowsExpr; }
3825+
Expr *getColsExpr() const { return ColsExpr; }
3826+
Expr *getUseExpr() const { return UseExpr; }
3827+
Expr *getScopeExpr() const { return ScopeExpr; }
3828+
3829+
bool isSugared() const { return false; }
3830+
QualType desugar() const { return QualType(this, 0); }
3831+
3832+
static bool classof(const Type *T) {
3833+
return T->getTypeClass() == DependentAttributedLinAlgMatrix;
3834+
}
3835+
3836+
void Profile(llvm::FoldingSetNodeID &ID) {
3837+
Profile(ID, Context, getWrappedType(), getComponentTyExpr(), getRowsExpr(),
3838+
getColsExpr(), getUseExpr(), getScopeExpr());
3839+
}
3840+
3841+
static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
3842+
QualType WrappedType, Expr *ComponentTyExpr,
3843+
Expr *RowsExpr, Expr *ColsExpr, Expr *UseExpr,
3844+
Expr *ScopeExpr);
3845+
};
3846+
3847+
// HLSL Change End
3848+
37393849
class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
37403850
// Helper data collector for canonical types.
37413851
struct CanonicalTTPTInfo {
@@ -5426,6 +5536,14 @@ inline bool Type::isEventT() const {
54265536
inline bool Type::isLinAlgMatrixType() const {
54275537
return isSpecificBuiltinType(BuiltinType::LinAlgMatrix);
54285538
}
5539+
5540+
inline bool Type::isAttributedLinAlgMatrixType() const {
5541+
return isa<AttributedLinAlgMatrixType>(this);
5542+
}
5543+
5544+
inline bool Type::isDependentAttributedLinAlgMatrixType() const {
5545+
return isa<DependentAttributedLinAlgMatrixType>(this);
5546+
}
54295547
// HLSL Change Ends
54305548

54315549
inline bool Type::isImageType() const {

tools/clang/include/clang/AST/TypeLoc.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,17 @@ class AttributedTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc,
823823
}
824824
};
825825

826+
// HLSL Change Start
827+
class AttributedLinAlgMatrixTypeLoc
828+
: public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
829+
AttributedLinAlgMatrixTypeLoc,
830+
AttributedLinAlgMatrixType> {};
831+
832+
class DependentAttributedLinAlgMatrixTypeLoc
833+
: public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
834+
DependentAttributedLinAlgMatrixTypeLoc,
835+
DependentAttributedLinAlgMatrixType> {};
836+
// HLSL Change End
826837

827838
struct ObjCObjectTypeLocInfo {
828839
SourceLocation TypeArgsLAngleLoc;

tools/clang/include/clang/AST/TypeNodes.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ TYPE(Record, TagType)
9292
TYPE(Enum, TagType)
9393
NON_CANONICAL_TYPE(Elaborated, Type)
9494
NON_CANONICAL_TYPE(Attributed, Type)
95+
// HLSL Change Start
96+
TYPE(AttributedLinAlgMatrix, Type)
97+
DEPENDENT_TYPE(DependentAttributedLinAlgMatrix, Type)
98+
// HLSL Change End
9599
DEPENDENT_TYPE(TemplateTypeParm, Type)
96100
NON_CANONICAL_TYPE(SubstTemplateTypeParm, Type)
97101
DEPENDENT_TYPE(SubstTemplateTypeParmPack, Type)

tools/clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def Borland : LangOpt<"Borland">;
233233
def CUDA : LangOpt<"CUDA">;
234234
def COnly : LangOpt<"CPlusPlus", 1>;
235235
def SPIRV : LangOpt<"SPIRV">; // SPIRV Change
236+
def HLSL : LangOpt<"HLSL">; // HLSL Change
236237

237238
// Defines targets for target-specific attributes. The list of strings should
238239
// specify architectures for which the target applies, based off the ArchType
@@ -1228,6 +1229,13 @@ def HLSLUnboundedSparseNodes : InheritableParamAttr {
12281229
let Documentation = [Undocumented];
12291230
}
12301231

1232+
def HLSLLinAlgMatrixAttributes : TypeAttr {
1233+
let Spellings = [CXX11<"", "__LinAlgMatrix_Attributes", 2015>];
1234+
let LangOpts = [HLSL];
1235+
let Args = [ExprArgument<"ComponentTy">, ExprArgument<"M">, ExprArgument<"N">,
1236+
ExprArgument<"Use">, ExprArgument<"Scope">];
1237+
let Documentation = [Undocumented];
1238+
}
12311239
// HLSL Change Ends
12321240

12331241
// SPIRV Change Starts

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8042,6 +8042,19 @@ def err_hlsl_linalg_matrix_dim_must_be_greater_than_zero: Error<
80428042
def err_hlsl_linalg_matrix_layout_invalid : Error<
80438043
"matrix layout %0 is not valid, must be in the range [%1, %2]">;
80448044

8045+
// SM 6.10 Linear Algebra Operations
8046+
def err_hlsl_linalg_matrix_attribute_arg_not_int_or_enum
8047+
: Error<"argument is not an integer%select{| or enumeration}0">;
8048+
def err_hlsl_linalg_matrix_attribute_arg_not_constant_value
8049+
: Error<"matrix attributes argument %1 is not a constant value">;
8050+
def err_hlsl_linalg_matrix_invalid_enum_attribute_value
8051+
: Error<"matrix attribute %0 has invalid value %1, must be in the range "
8052+
"[%2, %3]">;
8053+
def err_hlsl_linalg_matrix_attribute_on_invalid_type
8054+
: Error<"matrix attributes can only be applied to %0">;
8055+
def err_hlsl_linalg_attributed_matrix_required
8056+
: Error<"argument must be linear algebra matrix type">;
8057+
80458058
def err_hlsl_linalg_mul_muladd_output_vector_size_not_equal_to_matrix_M : Error<
80468059
"output vector length must be equal to Matrix M dimension in a linalg Mul/MulAdd operation">;
80478060
def err_hlsl_linalg_mul_muladd_unpacked_input_vector_size_not_equal_to_matrix_K : Error<

0 commit comments

Comments
 (0)