Skip to content

Commit ff6b0b9

Browse files
authored
[SM6.10][Bugfix] Update built-in definitions to allow vectors of different sizes (#8347)
Applies to `Convert` and `OuterProduct` functions. Fixes #8344
1 parent 11e3590 commit ff6b0b9

2 files changed

Lines changed: 19 additions & 7 deletions

File tree

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,23 @@
55
using namespace dx::linalg;
66

77
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
8-
using MatrixAccumTy = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
8+
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
9+
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
910

1011
ByteAddressBuffer BAB : register(t0);
1112

1213
[numthreads(4, 4, 4)]
1314
void main(uint ID : SV_GroupID) {
1415

1516
// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0(
16-
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2)
17+
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 2)
1718
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
1819
MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8);
1920

2021
vector<half, 4> vec1 = 10.3f;
2122

2223
// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623,
23-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
24+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
2425
// CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
2526
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);
2627

@@ -61,13 +62,24 @@ void main(uint ID : SV_GroupID) {
6162
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
6263
vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias);
6364

64-
// CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0
65+
// CHECK: %[[ACCUM1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0
6566
// CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619,
6667
// CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
67-
MatrixAccumTy AccumMatrix = OuterProduct<ComponentType::F16>(vec5, vec6);
68+
MatrixAccum_8_8_Ty AccumMatrix1 = OuterProduct<ComponentType::F16>(vec5, vec6);
69+
70+
// CHECK: %[[ACCUM2:.*]] = call %dx.types.LinAlgMatrixC8M8N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC8M8N4U2S0.v8f16.v4f16(
71+
// CHECK-SAME: i32 -2147483619, <8 x half> %[[VEC5]], <4 x half> %[[VEC20]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
72+
MatrixAccum_8_4_Ty AccumMatrix2 = OuterProduct<ComponentType::F16>(vec5, vec20);
6873

6974
// CHECK: %[[CONV_VEC:.*]] = call <8 x float> @dx.op.linAlgConvert.v8f32.v8f16(i32 -2147483618,
7075
// CHECK-SAME: <8 x half> %[[VEC6]], i32 8, i32 9) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
7176
InterpretedVector<float, 8, ComponentType::F32> convertedVec;
7277
convertedVec = Convert<ComponentType::F32, ComponentType::F16>(vec6);
78+
79+
// CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %21, i32 8, i32 21)
80+
// CHECK: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
81+
typedef vector<half, 16> half16;
82+
half16 srcF16 = BAB.Load<half16>(128);
83+
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> convertedPacked = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(srcF16);
84+
7385
}

utils/hct/gen_intrin_main.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, i
408408
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
409409
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
410410
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
411-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);
412-
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<> ret, in numeric<> vec, in uint input_interp, in uint output_interp);
411+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<c> vecA, in numeric<c2> vecB);
412+
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<c> ret, in numeric<c2> vec, in uint input_interp, in uint output_interp);
413413

414414
} namespace
415415

0 commit comments

Comments
 (0)