Skip to content

Commit 9e0d816

Browse files
authored
[SM6.10] Update SFINAE conditions in linalg.h (#8320)
Update SFINAE conditions in `linalg.h` based on the recent spec update (microsoft/hlsl-specs#833) and update tests. - thread-scope `Matrix::Load` to require `A` matrix - thread-scope `Matrix::InterlockedAccumulate` to require `Accumulator` matrix - groupshared `Matrix::InterlockedAccumulate` to require `Wave` scope Part of #7839
1 parent ccba9f8 commit 9e0d816

2 files changed

Lines changed: 25 additions & 13 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,12 @@ class Matrix {
323323
Stride, Layout, Align);
324324
}
325325

326-
template <typename T, MatrixUseEnum UseLocal = Use, SIZE_TYPE Size>
326+
template <typename T, MatrixUseEnum UseLocal = Use,
327+
MatrixScopeEnum ScopeLocal = Scope, SIZE_TYPE Size>
327328
typename hlsl::enable_if<
328329
hlsl::is_arithmetic<T>::value && Use == MatrixUse::Accumulator &&
329-
UseLocal == Use && (M * N / ElementsPerScalar <= Size),
330+
UseLocal == Use && (M * N / ElementsPerScalar <= Size) &&
331+
Scope == MatrixScope::Wave && ScopeLocal == Scope,
330332
void>::type
331333
InterlockedAccumulate(groupshared T Arr[Size], uint StartIdx, uint Stride,
332334
MatrixLayoutEnum Layout) {
@@ -370,18 +372,23 @@ class Matrix<ComponentTy, M, N, Use, MatrixScope::Thread> {
370372
ComponentTy, M, N, Use, MatrixScope::Thread)]];
371373
HandleT __handle;
372374

373-
template <MatrixLayoutEnum Layout>
374-
static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
375-
uint Align = sizeof(ElementType)) {
375+
template <MatrixLayoutEnum Layout, MatrixUseEnum UseLocal = Use>
376+
static typename hlsl::enable_if<Use == MatrixUse::A && UseLocal == Use,
377+
Matrix>::type
378+
Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
379+
uint Align = sizeof(ElementType)) {
376380
Matrix Result;
377381
__builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset,
378382
Stride, Layout, Align);
379383
return Result;
380384
}
381385

382-
void InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset,
383-
uint Stride, MatrixLayoutEnum Layout,
384-
uint Align = sizeof(ElementType)) {
386+
template <MatrixUseEnum UseLocal = Use>
387+
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
388+
void>::type
389+
InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
390+
MatrixLayoutEnum Layout,
391+
uint Align = sizeof(ElementType)) {
385392
__builtin_LinAlg_MatrixAccumulateToDescriptor(__handle, Res, StartOffset,
386393
Stride, Layout, Align);
387394
}

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MatrixBTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::B, MatrixScope::Wa
99
using MatrixBTyInt = Matrix<ComponentType::I32, 4, 4, MatrixUse::B, MatrixScope::Wave>;
1010
using MatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Wave>;
1111
using TSMatrixATy = Matrix<ComponentType::F32, 4, 4, MatrixUse::A, MatrixScope::Thread>;
12+
using TSMatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
1213

1314
ByteAddressBuffer BAB : register(t0);
1415
RWByteAddressBuffer RWBAB : register(u0);
@@ -128,7 +129,7 @@ void main(uint ID : SV_GroupID)
128129
// Matrix::InterlockedAccumulate to groupshared memory
129130
//
130131
// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC9M4N4U2S1.f32(i32 -2147483620,
131-
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %18,
132+
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM0]],
132133
// CHECK-SAME: float addrspace(3)* getelementptr inbounds ([256 x float],
133134
// CHECK-SAME: [256 x float] addrspace(3)* @"\01?SharedArr@@3PAMA", i32 0, i32 0), i32 0, i32 16, i32 1)
134135
// CHECK-SAME: ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout)
@@ -166,16 +167,20 @@ void main(uint ID : SV_GroupID)
166167
// Matrix::Load for thread-scope matrix
167168
//
168169
// CHECK: %[[TSMATA:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S0(
169-
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %24, i32 0, i32 16, i32 1, i32 4)
170+
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
170171
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
171172
TSMatrixATy TSMatA = TSMatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);
172173

173174
// Matrix::InterlockedAccumulate for thread-scope matrix
174175
//
175-
// CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U0S0(i32 -2147483621,
176-
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S0 %25, %dx.types.Handle %26, i32 0, i32 16, i32 1, i32 4)
176+
// CHECK: %[[TSACCUM:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC9M4N4U2S0.v4f32.v4f32
177+
// CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U2S0(i32 -2147483621,
178+
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S0 %[[TSACCUM]], %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
177179
// CHECK-SAME: ; LinAlgMatrixAccumulateToDescriptor(matrix,handle,offset,stride,layout,align)
178-
TSMatA.InterlockedAccumulate(RWBAB, 0, 16, MatrixLayoutEnum::ColMajor);
180+
vector<float, 4> vec1 = 1.0f;
181+
vector<float, 4> vec2 = 2.0f;
182+
TSMatrixAccumTy TSMatAccum = OuterProduct<ComponentType::F32>(vec1, vec2);
183+
TSMatAccum.InterlockedAccumulate(RWBAB, 0, 16, MatrixLayoutEnum::ColMajor);
179184

180185
// CHECK: call i32 @dx.op.linAlgMatrixQueryAccumulatorLayout(i32 -2147483626) ; LinAlgMatrixQueryAccumulatorLayout()
181186
MatrixUseEnum layout = AccumulatorLayout();

0 commit comments

Comments
 (0)