Skip to content

Commit e310fd6

Browse files
committed
Add test with variations for matVecMulAdd
1 parent 6eb36e4 commit e310fd6

1 file changed

Lines changed: 87 additions & 0 deletions

File tree

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F16 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=F16 | FileCheck %s --check-prefixes DXIL,DXIL-0
2+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E4M3 -DMI=F8_E4M3 -DML=ColumnMajor -DMT=1 -DBI=F16 | FileCheck %s --check-prefixes DXIL,DXIL-1
3+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=float16_t -DIU=0 -DITY=float16_t -DII=F8_E5M2 -DMI=F8_E5M2 -DML=MulOptimal -DMT=0 -DBI=F16 | FileCheck %s --check-prefixes DXIL,DXIL-2
4+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=I8 -DMI=I8 -DML=OuterProductOptimal -DMT=1 -DBI=I32 | FileCheck %s --check-prefixes DXIL,DXIL-3
5+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=float -DII=I8 -DMI=I8 -DML=RowMajor -DMT=0 -DBI=I32 | FileCheck %s --check-prefixes DXIL,DXIL-4
6+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=1 -DOTY=uint -DIU=0 -DITY=float -DII=I8 -DMI=F16 -DML=RowMajor -DMT=0 -DBI=I8 | FileCheck %s --check-prefixes DXIL,DXIL-5
7+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=1 -DITY=uint -DII=U8 -DMI=I8 -DML=ColumnMajor -DMT=1 -DBI=I8 | FileCheck %s --check-prefixes DXIL,DXIL-6
8+
// RUN: %dxc -T cs_6_9 %s -enable-16bit-types -DOU=0 -DOTY=int -DIU=0 -DITY=int -DII=U8 -DMI=U8 -DML=MulOptimal -DMT=0 -DBI=I8 | FileCheck %s --check-prefixes DXIL,DXIL-7
9+
10+
// Test the minimum supported set of combinations for matVecMulAdd
11+
// DXIL: define void @main()
12+
// DXIL-0: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 8, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
13+
// DXIL-1: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 21, %dx.types.Handle {{[^ ]+}}, i32 0, i32 21, i32 8, i32 8, i32 1, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
14+
// DXIL-2: call <4 x half> @dx.op.matVecMulAdd.v4f16.v8f16(i32 306, <8 x half> {{[^ ]+}}, i1 false, i32 22, %dx.types.Handle {{[^ ]+}}, i32 0, i32 22, i32 8, i32 8, i32 2, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
15+
// DXIL-3: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 3, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 4, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
16+
// DXIL-4: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8f32(i32 306, <8 x float> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 4, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
17+
18+
// Test unsigned variations
19+
// DXIL-5: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8f32(i32 306, <8 x float> {{[^ ]+}}, i1 false, i32 20, %dx.types.Handle {{[^ ]+}}, i32 0, i32 8, i32 8, i32 8, i32 0, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i1 true) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
20+
// DXIL-6: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 true, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i32 8, i32 8, i32 1, i1 true, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
21+
// DXIL-7: call <4 x i32> @dx.op.matVecMulAdd.v4i32.v8i32(i32 306, <8 x i32> {{[^ ]+}}, i1 false, i32 19, %dx.types.Handle {{[^ ]+}}, i32 0, i32 19, i32 8, i32 8, i32 2, i1 false, i32 64, %dx.types.Handle {{[^ ]+}}, i32 0, i32 20, i1 false) ; MatVecMulAdd(inputVector,isInputUnsigned,inputInterpretation,matrixBuffer,matrixOffset,matrixIntepretation,matrixM,matrixK,matrixLayout,matrixTranspose,matrixStride,biasBuffer,biasOffset,biasIntepretation,isOutputUnsigned)
22+
23+
24+
// Resources used by the test shader.
// input_vector_buffer: main() loads the 8-element input vector from offset 0.
// matrix_buffer / bias_buffer: passed as the matrix and bias resources to
// __builtin_MatVecMulAdd.
// rw_matrix_buffer: declared but not referenced by main() in this file.
ByteAddressBuffer input_vector_buffer;
ByteAddressBuffer matrix_buffer;
ByteAddressBuffer bias_buffer;
RWByteAddressBuffer rw_matrix_buffer;
28+
29+
// Component-type codes. The RUN lines select enumerators from this list via
// -DII=, -DMI= and -DBI=, and main() forwards them as the input, matrix and
// bias interpretation arguments. The FileCheck patterns expect these exact
// integers as i32 operands, so the explicit values must not change.
enum CompType {
  Invalid = 0,
  I1 = 1,
  I16 = 2,
  U16 = 3,
  I32 = 4,
  U32 = 5,
  I64 = 6,
  U64 = 7,
  F16 = 8,
  F32 = 9,
  F64 = 10,
  SNormF16 = 11,
  UNormF16 = 12,
  SNormF32 = 13,
  UNormF32 = 14,
  SNormF64 = 15,
  UNormF64 = 16,
  PackedS8x32 = 17,
  PackedU8x32 = 18,

  // BEGIN NEW FOR SM 6.9
  U8 = 19,
  I8 = 20,
  F8_E4M3 = 21,
  F8_E5M2 = 22,
};
56+
57+
// Matrix layout codes. The RUN lines select one via -DML=, and main() forwards
// it as the matrix_layout argument. The FileCheck patterns expect these exact
// integers as the matrixLayout operand.
enum MatLayout {
  RowMajor = 0,
  ColumnMajor = 1,
  MulOptimal = 2,
  OuterProductOptimal = 3,
};
63+
64+
// Compute-shader entry point that issues a single __builtin_MatVecMulAdd call.
//
// Every varying argument comes from a macro supplied with -D on the RUN lines:
//   OTY / OU : output element type and is-output-unsigned flag
//   ITY / IU : input element type and is-input-unsigned flag
//   II, MI, BI : input, matrix and bias interpretation (CompType enumerators)
//   ML       : matrix layout (MatLayout enumerator)
//   MT       : matrix-transposed flag (0 or 1)
// The remaining arguments (offsets 0, 8x8 dimensions, stride 64) are fixed and
// appear as literal operands in the FileCheck patterns.
[NumThreads(1,1,1)]
void main()
{
    // Written by the builtin; not read afterwards in this test.
    vector<OTY, 4> output_vector;
    static const uint is_output_unsigned = OU;

    vector<ITY, 8> input_vector = input_vector_buffer.Load<vector<ITY, 8> >(0);
    const uint is_input_unsigned = IU;
    const uint input_interpretation = II;

    const uint matrix_offset = 0;
    const uint matrix_interpretation = MI;
    const uint matrix_dimM = 8;
    const uint matrix_dimK = 8;
    const uint matrix_layout = ML;
    const bool matrix_is_transposed = (bool) MT;
    const uint matrix_stride = 64;

    const uint bias_offset = 0;
    const uint bias_interpretation = BI;

    __builtin_MatVecMulAdd(output_vector, is_output_unsigned, input_vector, is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset, matrix_interpretation,
        matrix_dimM, matrix_dimK, matrix_layout, matrix_is_transposed, matrix_stride, bias_buffer, bias_offset, bias_interpretation);
}

0 commit comments

Comments
 (0)