Skip to content

Commit 8cb7673

Browse files
authored
[SM6.10] Add Convert function to linalg.h (#8328)
Adds `Convert` function to `linalg.h`. This function converts a vector to an `InterpretedVector` of a different component type. It has been added to the spec in microsoft/hlsl-specs#819. It is implemented by calling a built-in function that was added in #8308.
1 parent 3531468 commit 8cb7673

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

tools/clang/lib/Headers/hlsl/dx/linalg.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ __MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I64, int64_t)
185185
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t)
186186
__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double)
187187

188+
template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
189+
static const int Value =
190+
(SrcN * ComponentTypeTraits<SrcTy>::ElementsPerScalar) /
191+
ComponentTypeTraits<DstTy>::ElementsPerScalar;
192+
};
193+
188194
} // namespace __detail
189195

190196
template <ComponentEnum ElementType, uint DimA> struct VectorRef {
@@ -205,6 +211,17 @@ InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
205211
return IV;
206212
}
207213

214+
template <ComponentEnum DestTy, ComponentEnum OriginTy, typename T, int N>
215+
InterpretedVector<typename __detail::ComponentTypeTraits<DestTy>::Type,
216+
__detail::DstN<DestTy, OriginTy, N>::Value, DestTy>
217+
Convert(vector<T, N> Vec) {
218+
vector<typename __detail::ComponentTypeTraits<DestTy>::Type,
219+
__detail::DstN<DestTy, OriginTy, N>::Value>
220+
Result;
221+
__builtin_LinAlg_Convert(Result, Vec, OriginTy, DestTy);
222+
return MakeInterpretedVector<DestTy>(Result);
223+
}
224+
208225
template <ComponentEnum ComponentTy, SIZE_TYPE M, SIZE_TYPE N,
209226
MatrixUseEnum Use, MatrixScopeEnum Scope>
210227
class Matrix {

tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,9 @@ void main(uint ID : SV_GroupID) {
6363
// CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619,
6464
// CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
6565
MatrixAccumTy AccumMatrix = OuterProduct<ComponentType::F16>(vec5, vec6);
66+
67+
// CHECK: %[[CONV_VEC:.*]] = call <8 x float> @dx.op.linAlgConvert.v8f32.v8f16(i32 -2147483618,
68+
// CHECK-SAME: <8 x half> %[[VEC6]], i32 8, i32 9) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
69+
InterpretedVector<float, 8, ComponentType::F32> convertedVec;
70+
convertedVec = Convert<ComponentType::F32, ComponentType::F16>(vec6);
6671
}

0 commit comments

Comments
 (0)