Skip to content

Commit 4220939

Browse files
committed
Update after isOutputSigned argument was added to MatVecMultiply*
1 parent f8f2d31 commit 4220939

2 files changed

Lines changed: 39 additions & 19 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ __ARITHMETIC_TYPE(half)
3232
__ARITHMETIC_TYPE(float)
3333
__ARITHMETIC_TYPE(double)
3434

35+
template <typename T> struct is_signed {
36+
static const bool value = true;
37+
};
38+
39+
#define __UNSIGNED_TYPE(type) \
40+
template <> struct is_signed<type> { \
41+
static const bool value = false; \
42+
};
43+
44+
#if __HLSL_ENABLE_16_BIT
45+
__UNSIGNED_TYPE(uint16_t)
46+
#endif
47+
__UNSIGNED_TYPE(uint)
48+
__UNSIGNED_TYPE(uint64_t)
49+
50+
#undef __UNSIGNED_TYPE
51+
3552
template <bool B, typename T> struct enable_if {};
3653

3754
template <typename T> struct enable_if<true, T> {
@@ -427,8 +444,8 @@ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElT
427444
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
428445
vector<InputElTy, K> Vec) {
429446
vector<OutputElTy, M> Result;
430-
__builtin_LinAlg_MatrixVectorMultiply(Result, MatrixA.__handle, Vec,
431-
MatrixDT);
447+
__builtin_LinAlg_MatrixVectorMultiply(
448+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, Vec, MatrixDT);
432449
return Result;
433450
}
434451

@@ -440,7 +457,8 @@ typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElT
440457
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
441458
vector<InputElTy, M> Vec, vector<BiasElTy, K> Bias) {
442459
vector<OutputElTy, K> Result;
443-
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, Vec,
460+
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
461+
hlsl::is_signed<OutputElTy>::value, Vec,
444462
MatrixDT, Bias, MatrixDT);
445463
return Result;
446464
}
@@ -458,23 +476,25 @@ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
458476
vector<BiasElTy, K> Bias) {
459477
vector<OutputElTy, K> Result;
460478
__builtin_LinAlg_MatrixVectorMultiplyAdd(
461-
Result, MatrixA.__handle, InterpVec.Data, InterpVec.Interpretation, Bias,
462-
MatrixDT);
479+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, InterpVec.Data,
480+
InterpVec.Interpretation, Bias, MatrixDT);
463481
return Result;
464482
}
465483

466484
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
467485
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
468486
// clang-format off
469-
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value, vector<OutputElTy, K> >::type
487+
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
488+
vector<OutputElTy, K> >::type
470489
// clang-format on
471490
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
472491
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef) {
473492
using BiasVecTy =
474493
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, K>;
475494
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
476495
vector<OutputElTy, K> Result;
477-
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, Vec,
496+
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
497+
hlsl::is_signed<OutputElTy>::value, Vec,
478498
MatrixDT, BiasVec, BiasElTy);
479499
return Result;
480500
}
@@ -495,8 +515,8 @@ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
495515
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
496516
vector<OutputElTy, K> Result;
497517
__builtin_LinAlg_MatrixVectorMultiplyAdd(
498-
Result, MatrixA.__handle, InterpVec.Data, InterpVec.Interpretation,
499-
BiasVec, BiasElTy);
518+
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value, InterpVec.Data,
519+
InterpVec.Interpretation, BiasVec, BiasElTy);
500520
return Result;
501521
}
502522

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ void main(uint ID : SV_GroupID) {
2020
vector<half, 8> vec1 = 10.3f;
2121

2222
// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N8U0S0.v8f16(i32 -2147483623,
23-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
23+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
2424
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8)
25-
// CHECK-SAME: ; LinAlgMatVecMul(matrix,inputVector,interpretation)
25+
// CHECK-SAME: ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
2626
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);
2727

2828
// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
29-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
29+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
3030
// CHECK-SAME: half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926, half 0xH4926>, i32 8, <8 x half> %[[VEC2]], i32 8)
31-
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
31+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
3232
vector<half, 8> vec3 = MultiplyAdd<half>(Mat1, vec1, vec2);
3333

3434
// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622,
35-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8)
36-
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
35+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8)
36+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
3737
InterpretedVector<half, 8, ComponentType::F16> interpVec2 = MakeInterpretedVector<ComponentType::F16>(vec2);
3838
vector<half, 8> vec4 = MultiplyAdd<half>(Mat1, interpVec2, vec3);
3939

@@ -43,8 +43,8 @@ void main(uint ID : SV_GroupID) {
4343
// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0
4444

4545
// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
46-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
47-
// CHECK-SAME:; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
46+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
47+
// CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
4848
VectorRef<ComponentType::I16, 8> memBias = {BAB, 4096};
4949
vector<half, 8> vec5 = MultiplyAdd<half>(Mat1, vec3, memBias);
5050

@@ -55,8 +55,8 @@ void main(uint ID : SV_GroupID) {
5555
// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0
5656

5757
// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622,
58-
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
59-
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,inputVector,inputInterpretation,biasVector,biasInterpretation)
58+
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
59+
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
6060
vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias);
6161

6262
// CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0

0 commit comments

Comments
 (0)