|
5 | 5 | using namespace dx::linalg; |
6 | 6 |
|
7 | 7 | using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>; |
8 | | -using MatrixAccumTy = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>; |
| 8 | +using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>; |
| 9 | +using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>; |
9 | 10 |
|
10 | 11 | ByteAddressBuffer BAB : register(t0); |
11 | 12 |
|
12 | 13 | [numthreads(4, 4, 4)] |
13 | 14 | void main(uint ID : SV_GroupID) { |
14 | 15 |
|
15 | 16 | // CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0( |
16 | | -// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2) |
| 17 | +// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 2) |
17 | 18 | // CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) |
18 | 19 | MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8); |
19 | 20 |
|
20 | 21 | vector<half, 4> vec1 = 10.3f; |
21 | 22 |
|
22 | 23 | // CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623, |
23 | | -// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926, |
| 24 | +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926, |
24 | 25 | // CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation) |
25 | 26 | vector<half, 8> vec2 = Multiply<half>(Mat1, vec1); |
26 | 27 |
|
@@ -61,13 +62,24 @@ void main(uint ID : SV_GroupID) { |
61 | 62 | // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) |
62 | 63 | vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias); |
63 | 64 |
|
64 | | - // CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0 |
| 65 | + // CHECK: %[[ACCUM1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0 |
65 | 66 | // CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619, |
66 | 67 | // CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB) |
67 | | - MatrixAccumTy AccumMatrix = OuterProduct<ComponentType::F16>(vec5, vec6); |
| 68 | + MatrixAccum_8_8_Ty AccumMatrix1 = OuterProduct<ComponentType::F16>(vec5, vec6); |
| 69 | + |
| 70 | + // CHECK: %[[ACCUM2:.*]] = call %dx.types.LinAlgMatrixC8M8N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC8M8N4U2S0.v8f16.v4f16( |
| 71 | + // CHECK-SAME: i32 -2147483619, <8 x half> %[[VEC5]], <4 x half> %[[VEC20]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB) |
| 72 | + MatrixAccum_8_4_Ty AccumMatrix2 = OuterProduct<ComponentType::F16>(vec5, vec20); |
68 | 73 |
|
69 | 74 | // CHECK: %[[CONV_VEC:.*]] = call <8 x float> @dx.op.linAlgConvert.v8f32.v8f16(i32 -2147483618, |
70 | 75 | // CHECK-SAME: <8 x half> %[[VEC6]], i32 8, i32 9) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) |
71 | 76 | InterpretedVector<float, 8, ComponentType::F32> convertedVec; |
72 | 77 | convertedVec = Convert<ComponentType::F32, ComponentType::F16>(vec6); |
| 78 | + |
| 79 | + // CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %{{[0-9]+}}, i32 8, i32 21) |
| 80 | + // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) |
| 81 | + typedef vector<half, 16> half16; |
| 82 | + half16 srcF16 = BAB.Load<half16>(128); // 16 halfs read from BAB at offset 128 feed the packed convert |
| 83 | + InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> convertedPacked = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(srcF16); // 16 x F16 in, 4 x uint out |
| 84 | + |
73 | 85 | } |
0 commit comments