@@ -7,6 +7,7 @@ using namespace dx::linalg;
77using MatrixATy = Matrix<ComponentType::F16, 8 , 4 , MatrixUse::A, MatrixScope::Thread>;
88using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8 , 8 , MatrixUse::Accumulator, MatrixScope::Thread>;
99using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8 , 4 , MatrixUse::Accumulator, MatrixScope::Thread>;
10+ using Matrix_7_15_ATy = Matrix<ComponentType::F16, 7 , 15 , MatrixUse::A, MatrixScope::Thread>;
1011
1112ByteAddressBuffer BAB : register (t0);
1213
@@ -87,4 +88,60 @@ void main(uint ID : SV_GroupID) {
8788 half3 ThreeF16 = BAB.Load<half3 >(256 );
8889 InterpretedVector<uint , 1 , ComponentEnum::F8_E4M3FN> convertedPacked2 =
8990 Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(ThreeF16);
91+
92+ // Test MultiplyAdd with odd sizes
93+ //
94+ vector <half , 15 > vecH15 = BAB.Load< vector <half , 15 > >(168 );
95+ vector <half , 7 > vecH7 = BAB.Load< vector <half , 7 > >(64 );
96+
97+ InterpretedVector<half , 15 , ComponentEnum::F16> interpVecH15 = MakeInterpretedVector<ComponentEnum::F16>(vecH15);
98+
99+ // CHECK: %[[MAT_7_15:.*]] = call %dx.types.LinAlgMatrixC8M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M7N15U0S0(i32 -2147483634,
100+ // CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
101+ Matrix_7_15_ATy Mat_7_15 = Matrix_7_15_ATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0 , 16 );
102+
103+ // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
104+ // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
105+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
106+ vector <half , 7 > vec7 = MultiplyAdd<half >(Mat_7_15, vecH15, vecH7);
107+
108+ // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]],
109+ // CHECK-SAME; i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %{{[0-9]+}}, i32 8)
110+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
111+ vector <half , 7 > vec8 = MultiplyAdd<half >(Mat_7_15, interpVecH15, vecH7);
112+
113+ // CHECK: %[[LOAD1:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
114+ // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
115+ // CHECK: %[[MEM_BIAS1:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD1]], 0
116+ // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v15f16.v7f16(i32 -2147483622,
117+ // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %29, i32 8, <7 x half> %37, i32 8)
118+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
119+ VectorRef<ComponentType::F16, 7 > memBias7 = {BAB, 512 };
120+ vector <half , 7 > vec9 = MultiplyAdd<half >(Mat_7_15, vecH15, memBias7);
121+
122+ // CHECK: %[[LOAD2:.*]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 2)
123+ // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
124+ // CHECK: %[[MEM_BIAS2:.*]] = extractvalue %dx.types.ResRet.v7f16 %[[LOAD2]], 0
125+ // CHECK-NEXT: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <15 x half> %{{[0-9]+}}, i32 8, <7 x half> %[[MEM_BIAS2]], i32 8)
126+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
127+ vector <half , 7 > vec10 = MultiplyAdd<half >(Mat_7_15, interpVecH15, memBias7);
128+
129+ // Test MultiplyAdd with packed input vector
130+ //
131+ // CHECK: %[[INTERP_VEC_H15_PACKED:.*]] = call <4 x i32> @dx.op.linAlgConvert.v4i32.v15f16(i32 -2147483618,
132+ // CHECK-SAME: <15 x half> %{{[0-9]+}}, i32 8, i32 21) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
133+ InterpretedVector<uint , 4 , ComponentEnum::F8_E4M3FN> interpVecH15Packed = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(vecH15);
134+
135+ // CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
136+ // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %43, i32 21, <7 x half> %31, i32 8)
137+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
138+ vector <half , 7 > vec11 = MultiplyAdd<half >(Mat_7_15, interpVecH15Packed, vecH7);
139+
140+ // CHECK: %[[LOAD3:.+]] = call %dx.types.ResRet.v7f16 @dx.op.rawBufferVectorLoad.v7f16(i32 303, %dx.types.Handle %45, i32 512, i32 undef, i32 2)
141+ // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
142+ // CHECK-NEXT: %[[MEM_BIAS3:.*]] = extractvalue %dx.types.ResRet.v7f16 %46, 0
143+ // CHECK-NEXT: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC8M7N15U0S0.v4i32.v7f16(i32 -2147483622,
144+ // CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8)
145+ // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
146+ vector <half , 7 > vec12 = MultiplyAdd<half >(Mat_7_15, interpVecH15Packed, memBias7);
90147}
0 commit comments