Skip to content

Commit d831cb4

Browse files
authored
[SM6.10][Bugfix] Fix Size check for input interpreted vector in MultiplyAdd (#8388)
Fixes vector size check for input interpreted vector. Adds tests for MultiplyAdd with odd vector and matrix sizes, and with packed input vector. Fixes #8385
1 parent 4ad9834 commit d831cb4

2 files changed

Lines changed: 67 additions & 2 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,14 @@ template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
204204
static const SIZE_TYPE N = MVal;
205205
};
206206

207+
// Compile-time helper: given a component type and a count of packed
// components, computes how many scalars are needed to hold them.
// ElementsPerScalar is read from ComponentTypeTraits<CompTy>.
template <ComponentEnum CompTy, SIZE_TYPE PackedComponentCount>
208+
struct ScalarCountFromPackedComponents {
209+
static const SIZE_TYPE ElementsPerScalar =
210+
ComponentTypeTraits<CompTy>::ElementsPerScalar;
211+
// Adding (ElementsPerScalar - 1) before the integer division rounds the
// quotient up, so a partially filled trailing scalar is still counted.
static const SIZE_TYPE Value =
212+
(PackedComponentCount + ElementsPerScalar - 1) / ElementsPerScalar;
213+
};
214+
207215
} // namespace __detail
208216

209217
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
@@ -506,7 +514,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
506514
ComponentEnum MatrixDT>
507515
// clang-format off
508516
typename hlsl::enable_if<
509-
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
517+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
510518
vector<OutputElTy, M> >::type
511519
// clang-format on
512520
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
@@ -542,7 +550,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
542550
ComponentEnum MatrixDT>
543551
// clang-format off
544552
typename hlsl::enable_if<
545-
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
553+
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
546554
vector<OutputElTy, M> >::type
547555
// clang-format on
548556
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using namespace dx::linalg;
77
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
88
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
99
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
10+
using Matrix_7_15_ATy = Matrix<ComponentType::F16, 7, 15, MatrixUse::A, MatrixScope::Thread>;
1011

1112
ByteAddressBuffer BAB : register(t0);
1213

@@ -87,4 +88,60 @@ void main(uint ID : SV_GroupID) {
8788
half3 ThreeF16 = BAB.Load<half3>(256);
8889
InterpretedVector<uint, 1, ComponentEnum::F8_E4M3FN> convertedPacked2 =
8990
Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(ThreeF16);
91+
92+
// Test MultiplyAdd with odd sizes
93+
//
94+
vector<half, 15> vecH15 = BAB.Load< vector<half, 15> >(168);
95+
vector<half, 7> vecH7 = BAB.Load< vector<half, 7> >(64);
96+
97+
InterpretedVector<half, 15, ComponentEnum::F16> interpVecH15 = MakeInterpretedVector<ComponentEnum::F16>(vecH15);
98+
99+
// CHECK: %[[MAT_7_15:.*]] = call %dx.types.LinAlgMatrixC8M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M7N15U0S0(i32 -2147483634,
100+
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
101+
Matrix_7_15_ATy Mat_7_15 = Matrix_7_15_ATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);
102+
103+
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
104+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
105+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
106+
vector<half, 7> vec7 = MultiplyAdd<half>(Mat_7_15, vecH15, vecH7);
107+
108+
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]],
109+
// CHECK-SAME: i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
110+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
111+
vector<half, 7> vec8 = MultiplyAdd<half>(Mat_7_15, interpVecH15, vecH7);
112+
113+
// CHECK: %[[LOAD1:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
114+
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
115+
// CHECK: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0
116+
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
117+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS1]], i32 8)
118+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
119+
VectorRef<ComponentType::F16, 7> memBias7 = {BAB, 512};
120+
vector<half, 7> vec9 = MultiplyAdd<half>(Mat_7_15, vecH15, memBias7);
121+
122+
// CHECK: %[[LOAD2:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
123+
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
124+
// CHECK: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0
125+
// CHECK-NEXT: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS2]], i32 8)
126+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
127+
vector<half, 7> vec10 = MultiplyAdd<half>(Mat_7_15, interpVecH15, memBias7);
128+
129+
// Test MultiplyAdd with packed input vector
130+
//
131+
// CHECK: %[[INTERP_VEC_H15_PACKED:.*]] = call <4 x i32> @dx.op.linAlgConvert.v4i32.v15f16(i32 -2147483618,
132+
// CHECK-SAME: <15 x half> %{{[0-9]+}}, i32 8, i32 21) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
133+
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> interpVecH15Packed = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(vecH15);
134+
135+
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
136+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %{{[0-9]+}}, i32 8)
137+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
138+
vector<half, 7> vec11 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, vecH7);
139+
140+
// CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
141+
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
142+
// CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD3]], 0
143+
// CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
144+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8)
145+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
146+
vector<half, 7> vec12 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, memBias7);
90147
}

0 commit comments

Comments
 (0)