Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
static const SIZE_TYPE N = MVal;
};

template <ComponentEnum CompTy, SIZE_TYPE PackedComponentCount>
struct ScalarCountFromPackedComponents {
static const SIZE_TYPE ElementsPerScalar =
ComponentTypeTraits<CompTy>::ElementsPerScalar;
static const SIZE_TYPE Value =
(PackedComponentCount + ElementsPerScalar - 1) / ElementsPerScalar;
};

} // namespace __detail

template <ComponentEnum ElementType, uint DimA> struct VectorRef {
Expand Down Expand Up @@ -506,7 +514,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
Expand Down Expand Up @@ -542,7 +550,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
VecK == __detail::ScalarCountFromPackedComponents<InputInterp, K>::Value,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
Expand Down
57 changes: 57 additions & 0 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using namespace dx::linalg;
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
using Matrix_7_15_ATy = Matrix<ComponentType::F16, 7, 15, MatrixUse::A, MatrixScope::Thread>;

ByteAddressBuffer BAB : register(t0);

Expand Down Expand Up @@ -87,4 +88,60 @@ void main(uint ID : SV_GroupID) {
half3 ThreeF16 = BAB.Load<half3>(256);
InterpretedVector<uint, 1, ComponentEnum::F8_E4M3FN> convertedPacked2 =
Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(ThreeF16);

// Test MultiplyAdd with odd sizes
//
vector<half, 15> vecH15 = BAB.Load< vector<half, 15> >(168);
vector<half, 7> vecH7 = BAB.Load< vector<half, 7> >(64);

InterpretedVector<half, 15, ComponentEnum::F16> interpVecH15 = MakeInterpretedVector<ComponentEnum::F16>(vecH15);

// CHECK: %[[MAT_7_15:.*]] = call %dx.types.LinAlgMatrixC8M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M7N15U0S0(i32 -2147483634,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
Matrix_7_15_ATy Mat_7_15 = Matrix_7_15_ATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec7 = MultiplyAdd<half>(Mat_7_15, vecH15, vecH7);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]],
// CHECK-SAME; i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec8 = MultiplyAdd<half>(Mat_7_15, interpVecH15, vecH7);

// CHECK: %[[LOAD1:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %29, i32 8, <7 x half> %37, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
VectorRef<ComponentType::F16, 7> memBias7 = {BAB, 512};
vector<half, 7> vec9 = MultiplyAdd<half>(Mat_7_15, vecH15, memBias7);

// CHECK: %[[LOAD2:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0
// CHECK-NEXT: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS2]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec10 = MultiplyAdd<half>(Mat_7_15, interpVecH15, memBias7);

// Test MultiplyAdd with packed input vector
//
// CHECK: %[[INTERP_VEC_H15_PACKED:.*]] = call <4 x i32> @dx.op.linAlgConvert.v4i32.v15f16(i32 -2147483618,
// CHECK-SAME: <15 x half> %{{[0-9]+}}, i32 8, i32 21) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> interpVecH15Packed = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(vecH15);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %43, i32 21, <7 x half> %31, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec11 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, vecH7);

// CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %45, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %46, 0
// CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec12 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, memBias7);
}
Loading