Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
InterpretedVector<InputElTy, VecK, InputInterp>::Size >= K,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
Expand Down Expand Up @@ -542,7 +542,7 @@ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
InterpretedVector<InputElTy, VecK, InputInterp>::Size >= K,
Comment thread
tex3d marked this conversation as resolved.
Outdated
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
Expand Down
57 changes: 57 additions & 0 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using namespace dx::linalg;
using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
using Matrix_7_15_ATy = Matrix<ComponentType::F16, 7, 15, MatrixUse::A, MatrixScope::Thread>;

ByteAddressBuffer BAB : register(t0);

Expand Down Expand Up @@ -87,4 +88,60 @@ void main(uint ID : SV_GroupID) {
half3 ThreeF16 = BAB.Load<half3>(256);
InterpretedVector<uint, 1, ComponentEnum::F8_E4M3FN> convertedPacked2 =
Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(ThreeF16);

// Test MultiplyAdd with odd sizes
//
vector<half, 15> vecH15 = BAB.Load< vector<half, 15> >(168);
vector<half, 7> vecH7 = BAB.Load< vector<half, 7> >(64);

InterpretedVector<half, 15, ComponentEnum::F16> interpVecH15 = MakeInterpretedVector<ComponentEnum::F16>(vecH15);

// CHECK: %[[MAT_7_15:.*]] = call %dx.types.LinAlgMatrixC8M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M7N15U0S0(i32 -2147483634,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
Matrix_7_15_ATy Mat_7_15 = Matrix_7_15_ATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec7 = MultiplyAdd<half>(Mat_7_15, vecH15, vecH7);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]],
// CHECK-SAME: i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec8 = MultiplyAdd<half>(Mat_7_15, interpVecH15, vecH7);

// CHECK: %[[LOAD1:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS1]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
VectorRef<ComponentType::F16, 7> memBias7 = {BAB, 512};
vector<half, 7> vec9 = MultiplyAdd<half>(Mat_7_15, vecH15, memBias7);

// CHECK: %[[LOAD2:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0
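// CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS2]], i32 8)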
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec10 = MultiplyAdd<half>(Mat_7_15, interpVecH15, memBias7);

// Test MultiplyAdd with packed input vector
//
// CHECK: %[[INTERP_VEC_H15_PACKED:.*]] = call <4 x i32> @dx.op.linAlgConvert.v4i32.v15f16(i32 -2147483618,
// CHECK-SAME: <15 x half> %{{[0-9]+}}, i32 8, i32 21) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> interpVecH15Packed = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(vecH15);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %{{[0-9]+}}, i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec11 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, vecH7);

// CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD3]], 0
// CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec12 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, memBias7);
}
Loading