diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h index 9b43dcc6cb..ff776f97b0 100644 --- a/tools/clang/lib/Headers/hlsl/dx/linalg.h +++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -165,11 +165,19 @@ template struct ComponentTypeTraits { static const uint ElementsPerScalar = 4; }; +template struct TypeTraits { + static const ComponentEnum CompType = + (ComponentEnum)dxil::ComponentType::Invalid; +}; + #define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type) \ template <> struct ComponentTypeTraits { \ using Type = type; \ static const bool IsNativeScalar = true; \ static const uint ElementsPerScalar = 1; \ + }; \ + template <> struct TypeTraits { \ + static const ComponentEnum CompType = enum_val; \ }; #if __HLSL_ENABLE_16_BIT @@ -498,14 +506,61 @@ Multiply(Matrix MatrixA, template // clang-format off -typename hlsl::enable_if::value, vector >::type +typename hlsl::enable_if::value && + __detail::TypeTraits::CompType == + __detail::TypeTraits::CompType, + vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, vector Vec, vector Bias) { vector Result; - __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, - hlsl::is_signed::value, - Vec, MatrixDT, Bias, MatrixDT); + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, Vec, + __detail::TypeTraits::CompType, Bias, + __detail::TypeTraits::CompType); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if::value && + __detail::TypeTraits::CompType != + __detail::TypeTraits::CompType, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + vector Vec, vector Bias) { + vector BiasVecConv; + __builtin_LinAlg_Convert(BiasVecConv, Bias, + __detail::TypeTraits::CompType, + __detail::TypeTraits::CompType); + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, Vec, + __detail::TypeTraits::CompType, BiasVecConv, + __detail::TypeTraits::CompType); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if< + VecK == __detail::ScalarCountFromPackedComponents::Value && + __detail::TypeTraits::CompType == + __detail::TypeTraits::CompType, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + InterpretedVector InterpVec, + vector Bias) { + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, + InterpVec.Data, InterpVec.Interpretation, Bias, + __detail::TypeTraits::CompType); return Result; } @@ -514,55 +569,121 @@ template // clang-format off typename hlsl::enable_if< - VecK == __detail::ScalarCountFromPackedComponents::Value, + VecK == __detail::ScalarCountFromPackedComponents::Value && + __detail::TypeTraits::CompType != + __detail::TypeTraits::CompType, vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, InterpretedVector InterpVec, vector Bias) { + + vector BiasVecConv; + __builtin_LinAlg_Convert(BiasVecConv, Bias, + __detail::TypeTraits::CompType, + __detail::TypeTraits::CompType); + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd( Result, MatrixA.__handle, hlsl::is_signed::value, - InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT); + InterpVec.Data, InterpVec.Interpretation, BiasVecConv, + __detail::TypeTraits::CompType); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if::value && + __detail::TypeTraits::CompType == BiasInterp, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + vector Vec, VectorRef BiasRef) { + using BiasOutputVecTy = vector; + BiasOutputVecTy BiasVec = + BiasRef.Buf.template Load(BiasRef.Offset); + + BiasOutputVecTy Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, + hlsl::is_signed::value, + Vec, MatrixDT, BiasVec, BiasInterp); return Result; } -template // clang-format off -typename hlsl::enable_if::value, +typename hlsl::enable_if::value && + __detail::TypeTraits::CompType != BiasInterp, vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, - vector Vec, VectorRef BiasRef) { + vector Vec, VectorRef BiasRef) { using BiasVecTy = - vector::Type, M>; + vector::Type, + __detail::ScalarCountFromPackedComponents::Value>; BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); + + vector BiasVecConv; + ComponentEnum OutputCompType = __detail::TypeTraits::CompType; + __builtin_LinAlg_Convert(BiasVecConv, BiasVec, BiasInterp, OutputCompType); + vector Result; - __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, - hlsl::is_signed::value, - Vec, MatrixDT, BiasVec, BiasElTy); + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, Vec, + __detail::TypeTraits::CompType, BiasVecConv, OutputCompType); return Result; } template // clang-format off typename hlsl::enable_if< - VecK == __detail::ScalarCountFromPackedComponents::Value, + VecK == __detail::ScalarCountFromPackedComponents::Value && + __detail::TypeTraits::CompType == BiasInterp, vector >::type // clang-format on MultiplyAdd(Matrix MatrixA, InterpretedVector InterpVec, - VectorRef BiasRef) { + VectorRef BiasRef) { + using BiasOutputVecTy = vector; + BiasOutputVecTy BiasVec = + BiasRef.Buf.template Load(BiasRef.Offset); + + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, + InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasInterp); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if< + VecK == __detail::ScalarCountFromPackedComponents::Value && + __detail::TypeTraits::CompType != BiasInterp, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + InterpretedVector InterpVec, + VectorRef BiasRef) { using BiasVecTy = - vector::Type, M>; + vector::Type, + __detail::ScalarCountFromPackedComponents::Value>; BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); + + ComponentEnum OutputCompType = __detail::TypeTraits::CompType; + vector BiasVecConv; + __builtin_LinAlg_Convert(BiasVecConv, BiasVec, BiasInterp, OutputCompType); + vector Result; __builtin_LinAlg_MatrixVectorMultiplyAdd( Result, MatrixA.__handle, hlsl::is_signed::value, - InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy); + InterpVec.Data, InterpVec.Interpretation, BiasVecConv, OutputCompType); return Result; } diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl index 58f19b887c..e952f2f721 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl @@ -8,6 +8,7 @@ using MatrixATy = Matrix; using MatrixAccum_8_4_Ty = Matrix; using Matrix_7_15_ATy = Matrix; +using MatrixPacked_7_15_ATy = Matrix; ByteAddressBuffer BAB : register(t0); @@ -46,8 +47,10 @@ void main(uint ID : SV_GroupID) { // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 - // CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK: %[[BIAS_CONV:.*]] = call <8 x half> @dx.op.linAlgConvert.v8f16.v8i16(i32 -2147483618, <8 x i16> %[[VEC_BIAS]], i32 2, i32 8) + // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + // CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x half> %[[BIAS_CONV]], i32 8) // CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) VectorRef memBias = {BAB, 4096}; vector vec5 = MultiplyAdd(Mat1, interpVec2, memBias); @@ -58,8 +61,10 @@ void main(uint ID : SV_GroupID) { // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 - // CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK: %[[BIAS_CONV:.*]] = call <8 x half> @dx.op.linAlgConvert.v8f16.v8i16(i32 -2147483618, <8 x i16> %[[VEC_BIAS]], i32 2, i32 8) + // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + // CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x half> %[[BIAS_CONV]], i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec6 = MultiplyAdd(Mat1, interpVec2, memBias); @@ -77,13 +82,13 @@ void main(uint ID : SV_GroupID) { InterpretedVector convertedVec; convertedVec = Convert(vec6); - // CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %21, i32 8, i32 21) + // CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %{{[0-9]+}}, i32 8, i32 21) // CHECK: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) typedef vector half16; half16 srcF16 = BAB.Load(128); InterpretedVector convertedPacked = Convert(srcF16); - // CHECK: call <1 x i32> @dx.op.linAlgConvert.v1i32.v3f16(i32 -2147483618, <3 x half> %25, i32 8, i32 21) + // CHECK: call <1 x i32> @dx.op.linAlgConvert.v1i32.v3f16(i32 -2147483618, <3 x half> %{{[0-9]+}}, i32 8, i32 21) // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) half3 ThreeF16 = BAB.Load(256); InterpretedVector convertedPacked2 = @@ -112,16 +117,16 @@ void main(uint ID : SV_GroupID) { // CHECK: %[[LOAD1:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2) // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) - // CHECK: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0 - // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %29, i32 8, <7 x half> %37, i32 8) + // CHECK-NEXT: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0 + // CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) VectorRef memBias7 = {BAB, 512}; vector vec9 = MultiplyAdd(Mat_7_15, vecH15, memBias7); // CHECK: %[[LOAD2:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2) // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) - // CHECK: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0 + // CHECK-NEXT: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0 // CHECK-NEXT: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS2]], i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec10 = MultiplyAdd(Mat_7_15, interpVecH15, memBias7); @@ -133,15 +138,51 @@ void main(uint ID : SV_GroupID) { InterpretedVector interpVecH15Packed = Convert(vecH15); // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622, - // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %43, i32 21, <7 x half> %31, i32 8) + // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %{{[0-9]+}}, i32 21, <7 x half> %{{[0-9]+}}, i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec11 = MultiplyAdd(Mat_7_15, interpVecH15Packed, vecH7); - // CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %45, i32 512, i32 undef, i32 2) + // CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2) // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) - // CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %46, 0 + // CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %{{[0-9]+}}, 0 // CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622, // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8) // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec12 = MultiplyAdd(Mat_7_15, interpVecH15Packed, memBias7); + + // Test Convert and MultiplyAdd with odd sizes and packed types + + // CHECK: %[[MAT_7_15_PACKED:.*]] = call %dx.types.LinAlgMatrixC21M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC21M7N15U0S0(i32 -2147483634, + // CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) + MatrixPacked_7_15_ATy Mat_7_15_Packed = MatrixPacked_7_15_ATy::Load(BAB, 0, 16); + + // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v15f16.v7f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + vector vec21 = MultiplyAdd(Mat_7_15_Packed, vecH15, vecH7); + + // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v4i32.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], + // CHECK-SAME: i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %{{[0-9]+}}, i32 8) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + vector vec22 = MultiplyAdd(Mat_7_15_Packed, interpVecH15Packed, vecH7); + + // CHECK: %[[LOAD4:.*]] = call %dx.types.ResRet.v2i32 @dx.op.rawBufferVectorLoad.v2i32(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 4) + // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) + // CHECK-NEXT: %[[MEM_BIAS_PACKED1:.*]] = extractvalue %dx.types.ResRet.v2i32 %[[LOAD4]], 0 + // CHECK-NEXT: %[[MEM_BIAS_CONV1:.*]] = call <7 x half> @dx.op.linAlgConvert.v7f16.v2i32(i32 -2147483618, + // CHECK-SAME: <2 x i32> %[[MEM_BIAS_PACKED1]], i32 21, i32 8) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + // CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v15f16.v7f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS_CONV1]], i32 8) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + VectorRef memBias7Packed = {BAB, 512}; + vector vec23 = MultiplyAdd(Mat_7_15_Packed, vecH15, memBias7Packed); + + // CHECK: %[[LOAD5:.*]] = call %dx.types.ResRet.v2i32 @dx.op.rawBufferVectorLoad.v2i32(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 4) + // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) + // CHECK-NEXT: %[[MEM_BIAS_PACKED2:.*]] = extractvalue %dx.types.ResRet.v2i32 %[[LOAD5]], 0 + // CHECK-NEXT: %[[MEM_BIAS_CONV2:.*]] = call <7 x half> @dx.op.linAlgConvert.v7f16.v2i32(i32 -2147483618, <2 x i32> %[[MEM_BIAS_PACKED2]], i32 21, i32 8) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + // CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v4i32.v7f16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS_CONV2]], i32 8) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + vector vec24 = MultiplyAdd(Mat_7_15_Packed, interpVecH15Packed, memBias7Packed); }