Skip to content

Commit b48703b

Browse files
authored
[SM6.10][Bugfix] Update Matrix::Cast definition to switch matrix sizes on transpose (#8368)
Spec updated in microsoft/hlsl-specs#842
1 parent 71bf911 commit b48703b

2 files changed

Lines changed: 31 additions & 10 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,16 @@ template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
194194
ComponentTypeTraits<DstTy>::ElementsPerScalar;
195195
};
196196

197+
template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
198+
static const SIZE_TYPE M = MVal;
199+
static const SIZE_TYPE N = NVal;
200+
};
201+
202+
template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
203+
static const SIZE_TYPE M = NVal;
204+
static const SIZE_TYPE N = MVal;
205+
};
206+
197207
} // namespace __detail
198208

199209
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
@@ -242,8 +252,12 @@ class Matrix {
242252

243253
template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
244254
bool Transpose = false>
245-
Matrix<NewCompTy, M, N, NewUse, Scope> Cast() {
246-
Matrix<NewCompTy, M, N, NewUse, Scope> Result;
255+
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
256+
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
257+
Cast() {
258+
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
259+
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
260+
Result;
247261
__builtin_LinAlg_CopyConvertMatrix(Result.__handle, __handle, Transpose);
248262
return Result;
249263
}

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ using namespace dx::linalg;
66

77
using MatrixATy = Matrix<ComponentType::F32, 4, 4, MatrixUse::A, MatrixScope::Wave>;
88
using MatrixBTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::B, MatrixScope::Wave>;
9-
using MatrixBTyInt = Matrix<ComponentType::I32, 4, 4, MatrixUse::B, MatrixScope::Wave>;
109
using MatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Wave>;
1110
using TSMatrixATy = Matrix<ComponentType::F32, 4, 4, MatrixUse::A, MatrixScope::Thread>;
1211
using TSMatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
1312

13+
using Matrix48TyFloat = Matrix<ComponentType::F32, 4, 8, MatrixUse::A, MatrixScope::Wave>;
14+
using Matrix48TyInt = Matrix<ComponentType::I32, 4, 8, MatrixUse::A, MatrixScope::Wave>;
15+
using Matrix84TyInt = Matrix<ComponentType::I32, 8, 4, MatrixUse::A, MatrixScope::Wave>;
16+
17+
1418
ByteAddressBuffer BAB : register(t0);
1519
RWByteAddressBuffer RWBAB : register(u0);
1620
groupshared float SharedArr[256];
@@ -34,16 +38,19 @@ void main(uint ID : SV_GroupID)
3438

3539
// Matrix::Cast
3640
//
37-
// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U0S1(
38-
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]], i1 false)
41+
// CHECK: %[[MAT48F:.*]] = call %dx.types.LinAlgMatrixC9M4N8U0S1 @dx.op.linAlgFillMatrix.mC9M4N8U0S1.f32(
42+
// CHECK-SAME: i32 -2147483636, float 3.000000e+00) ; LinAlgFillMatrix(value)
43+
44+
// CHECK: call %dx.types.LinAlgMatrixC4M4N8U0S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N8U0S1.mC9M4N8U0S1(
45+
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N8U0S1 %[[MAT48F]], i1 false)
3946
// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose)
40-
MatrixBTyInt MatBInt1 = MatA1.Cast<ComponentType::I32, MatrixUse::B>();
47+
Matrix48TyFloat Mat48F = Matrix48TyFloat::Splat(3.0f);
48+
Matrix48TyInt Mat48I = Mat48F.Cast<ComponentType::I32>();
4149

42-
// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U1S1(
43-
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i1 true)
50+
// CHECK: call %dx.types.LinAlgMatrixC4M8N4U0S1 @dx.op.linAlgCopyConvertMatrix.mC4M8N4U0S1.mC9M4N8U0S1(
51+
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N8U0S1 %[[MAT48F]], i1 true)
4452
// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose)
45-
MatrixBTyInt MatBInt2;
46-
MatBInt2 = MatB1.Cast<ComponentType::I32, MatrixUse::B, true>();
53+
Matrix84TyInt Mat84I = Mat48F.Cast<ComponentType::I32, MatrixUse::A, true>();
4754

4855
// Matrix::Load from ByteAddressBuffer
4956
//

0 commit comments

Comments
 (0)