Skip to content

Commit 0131a62

Browse files
authored
[SM6.10][Bugfix] Update vector sizes in linalg.h to match column-vector multiplication (#8338)
Matrix-vector APIs should be aligned with coopvec, meaning the multiplication would be matrix x column-based vector. Also fixes an issue with `__builtin_LinAlg_MatrixVectorMultiply*` built-ins that did not allow vectors of different sizes. Fixes #8335
1 parent 0a3545a commit 0131a62

3 files changed

Lines changed: 43 additions & 41 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
465465
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
466466
ComponentEnum MatrixDT>
467467
// clang-format off
468-
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, K> >::type
468+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
469469
// clang-format on
470470
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
471471
vector<InputElTy, K> Vec) {
@@ -479,29 +479,29 @@ Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
479479
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
480480
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
481481
// clang-format off
482-
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, K> >::type
482+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, M> >::type
483483
// clang-format on
484484
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
485-
vector<InputElTy, M> Vec, vector<BiasElTy, K> Bias) {
486-
vector<OutputElTy, K> Result;
485+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias) {
486+
vector<OutputElTy, M> Result;
487487
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
488488
hlsl::is_signed<OutputElTy>::value,
489489
Vec, MatrixDT, Bias, MatrixDT);
490490
return Result;
491491
}
492492

493493
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
494-
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
494+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
495495
ComponentEnum MatrixDT>
496496
// clang-format off
497497
typename hlsl::enable_if<
498-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
499-
vector<OutputElTy, K> >::type
498+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
499+
vector<OutputElTy, M> >::type
500500
// clang-format on
501501
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
502-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
503-
vector<BiasElTy, K> Bias) {
504-
vector<OutputElTy, K> Result;
502+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
503+
vector<BiasElTy, M> Bias) {
504+
vector<OutputElTy, M> Result;
505505
__builtin_LinAlg_MatrixVectorMultiplyAdd(
506506
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
507507
InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT);
@@ -512,35 +512,35 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
512512
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
513513
// clang-format off
514514
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
515-
vector<OutputElTy, K> >::type
515+
vector<OutputElTy, M> >::type
516516
// clang-format on
517517
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
518-
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef) {
518+
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef) {
519519
using BiasVecTy =
520-
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, K>;
520+
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
521521
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
522-
vector<OutputElTy, K> Result;
522+
vector<OutputElTy, M> Result;
523523
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
524524
hlsl::is_signed<OutputElTy>::value,
525525
Vec, MatrixDT, BiasVec, BiasElTy);
526526
return Result;
527527
}
528528

529529
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
530-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
530+
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
531531
ComponentEnum MatrixDT>
532532
// clang-format off
533533
typename hlsl::enable_if<
534-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
535-
vector<OutputElTy, K> >::type
534+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
535+
vector<OutputElTy, M> >::type
536536
// clang-format on
537537
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
538-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
539-
VectorRef<BiasElTy, K> BiasRef) {
538+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
539+
VectorRef<BiasElTy, M> BiasRef) {
540540
using BiasVecTy =
541-
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, K>;
541+
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
542542
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
543-
vector<OutputElTy, K> Result;
543+
vector<OutputElTy, M> Result;
544544
__builtin_LinAlg_MatrixVectorMultiplyAdd(
545545
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
546546
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy);

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,60 @@
44
#include <dx/linalg.h>
55
using namespace dx::linalg;
66

7-
using MatrixATy = Matrix<ComponentType::F16, 8, 8, MatrixUse::A, MatrixScope::Thread>;
7+
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
88
using MatrixAccumTy = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
99

1010
ByteAddressBuffer BAB : register(t0);
1111

1212
[numthreads(4, 4, 4)]
1313
void main(uint ID : SV_GroupID) {
1414

15-
// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N8U0S0(
15+
// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0(
1616
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2)
1717
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
1818
MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8);
1919

20-
vector<half, 8> vec1 = 10.3f;
20+
vector<half, 4> vec1 = 10.3f;
2121

22-
// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N8U0S0.v8f16(i32 -2147483623,
23-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
24-
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8)
25-
// CHECK-SAME: ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
22+
// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623,
23+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
24+
// CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
2625
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);
2726

28-
// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
29-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
30-
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8, <8 x half> %[[VEC2]], i32 8)
27+
// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622,
28+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
29+
// CHECK-SAME: half 0xH4926>, i32 8, <8 x half> %[[VEC2]], i32 8)
3130
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
3231
vector<half, 8> vec3 = MultiplyAdd<half>(Mat1, vec1, vec2);
3332

34-
// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
35-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8)
33+
// CHECK: %[[VEC20:.*]] = shufflevector
34+
vector<half, 4> vec20 = (vector<half, 4>)vec2;
35+
36+
// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622,
37+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x half> %[[VEC3]], i32 8)
3638
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
37-
InterpretedVector<half, 8, ComponentType::F16> interpVec2 = MakeInterpretedVector<ComponentType::F16>(vec2);
39+
InterpretedVector<half, 4, ComponentType::F16> interpVec2 = MakeInterpretedVector<ComponentType::F16>(vec20);
3840
vector<half, 8> vec4 = MultiplyAdd<half>(Mat1, interpVec2, vec3);
3941

4042
// CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303,
4143
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
4244

4345
// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0
4446

45-
// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
46-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
47+
// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
48+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
4749
// CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
4850
VectorRef<ComponentType::I16, 8> memBias = {BAB, 4096};
49-
vector<half, 8> vec5 = MultiplyAdd<half>(Mat1, vec3, memBias);
51+
vector<half, 8> vec5 = MultiplyAdd<half>(Mat1, interpVec2, memBias);
5052

5153
// CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303,
5254
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2)
5355
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
5456

5557
// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0
5658

57-
// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
58-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
59+
// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
60+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
5961
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
6062
vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias);
6163

utils/hct/gen_intrin_main.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
404404
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
405405
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);
406406
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
407-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp);
408-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp, in numeric<> bias, in uint biasInterp);
407+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp);
408+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
409409
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
410410
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
411411
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);

0 commit comments

Comments
 (0)