Skip to content

Commit e995e59

Browse files
authored
[Backport to 17] Extend INT4/FP4 packed conversions for i16, i64, and vector packed inputs (#3675) (#3712)
1 parent 411742c commit e995e59

3 files changed

Lines changed: 317 additions & 19 deletions

File tree

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4949,21 +4949,18 @@ processMiniFPOrInt4Type(Type *LLVMTy, FPEncodingWrap Encoding,
49494949
unsigned TyWidth = cast<IntegerType>(ScalarTy)->getBitWidth();
49504950
unsigned VecSize = 0;
49514951

4952-
if (TyWidth == 32) {
4953-
// Int4 or FP4 packed in 32-bit integer, change type and vector size.
4954-
assert((Encoding == FPEncodingWrap::E2M1 ||
4955-
Encoding == FPEncodingWrap::Integer) &&
4956-
"Unknown FP encoding");
4952+
bool IsPacked =
4953+
Encoding == FPEncodingWrap::E2M1 || Encoding == FPEncodingWrap::Integer;
4954+
if (IsPacked &&
4955+
(TyWidth == 8 || TyWidth == 16 || TyWidth == 32 || TyWidth == 64)) {
4956+
// Int4 or FP4 packed in an integer: each N-bit integer holds N/4 values.
49574957
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
49584958
"FP4 and Int4 matrices must not be packed");
4959-
VecSize = 8;
4960-
TyWidth = 4;
4961-
} else if (TyWidth == 8 && (Encoding == FPEncodingWrap::E2M1 ||
4962-
Encoding == FPEncodingWrap::Integer)) {
4963-
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
4964-
"FP4 and Int4 matrices must not be packed");
4965-
// Int4 or FP4 packed in 8-bit integer, change type and vector size.
4966-
VecSize = 2;
4959+
unsigned OuterVecLen =
4960+
LLVMTy->isVectorTy()
4961+
? cast<VectorType>(LLVMTy)->getElementCount().getFixedValue()
4962+
: 1;
4963+
VecSize = (TyWidth / 4) * OuterVecLen;
49674964
TyWidth = 4;
49684965
} else {
49694966
if (LLVMTy->isVectorTy())
@@ -5060,14 +5057,16 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
50605057
->getArgs());
50615058
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
50625059
} else if (FPDesc.SrcEncoding != FPEncodingWrap::Integer ||
5063-
(SrcTy->isTypeVector() && !LLVMSrcTy->isVectorTy())) {
5064-
// Create bitcast for FP4, FP8 and packed Int4.
5060+
SrcVecSize > 0) {
5061+
// Create bitcast for FP4, FP8 and packed Int4 (including cases where
5062+
// both the LLVM and SPIR-V types are vectors but with different
5063+
// sizes, e.g. <2 x i8> repacked as <4 x Int4>).
50655064
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
50665065
}
50675066
}
5067+
unsigned DstVecSize = 0;
50685068
if (!DstTy) {
50695069
// Dst type is 'mini' float or int4.
5070-
unsigned DstVecSize = 0;
50715070
DstTy = processMiniFPOrInt4Type(LLVMDstTy, FPDesc.DstEncoding,
50725071
GetScalarTy, BM, DstVecSize);
50735072

@@ -5095,10 +5094,9 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
50955094
if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 ||
50965095
FPDesc.DstEncoding == FPEncodingWrap::BF16)
50975096
return Conv;
5098-
// Originally not-packed integer.
5097+
// Originally not-packed integer (no repacking, or cooperative matrix).
50995098
if (FPDesc.DstEncoding == FPEncodingWrap::Integer &&
5100-
(DstTy->isTypeVector() == LLVMDstTy->isVectorTy() ||
5101-
isLLVMCooperativeMatrixType(LLVMDstTy)))
5099+
(DstVecSize == 0 || isLLVMCooperativeMatrixType(LLVMDstTy)))
51025100
return Conv;
51035101
// Need to adjust types: create bitcast for FP8 and packed Int4.
51045102
SPIRVValue *BitCast =

test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll

Lines changed: 150 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,15 @@
55
; 1. from packed Int4 to ... :
66
; a. packed in 32-bit
77
; b. packed in 8-bit
8+
; c. packed in 16-bit
9+
; d. packed in 64-bit
10+
; e. packed in vector of 8-bit integers
811
; 2. to packed Int4 from ... :
912
; a. packed in 32-bit
1013
; b. packed in 8-bit
14+
; c. packed in 16-bit
15+
; d. packed in 64-bit
16+
; e. packed in vector of 8-bit integers
1117

1218
; RUN: llvm-as %s -o %t.bc
1319
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_INTEL_int4,+SPV_KHR_bfloat16
@@ -28,29 +34,50 @@
2834
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_8:]] "int4_e4m3_8"
2935
; CHECK-SPIRV-DAG: Name [[#hf16_int4_32:]] "hf16_int4_32"
3036
; CHECK-SPIRV-DAG: Name [[#hf16_int4_8:]] "hf16_int4_8"
37+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_16:]] "int4_e4m3_16"
38+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_64:]] "int4_e4m3_64"
39+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_vec2xi8:]] "int4_e4m3_vec2xi8"
40+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_16:]] "hf16_int4_16"
41+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_64:]] "hf16_int4_64"
42+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_vec2xi8:]] "hf16_int4_vec2xi8"
3143

3244
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
3345
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Int32Const:]] 1
3446

3547
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
48+
; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0
49+
; CHECK-SPIRV-DAG: TypeInt [[#Int64Ty:]] 64 0
3650
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec8Ty:]] [[#Int8Ty]] 8
3751
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec2Ty:]] [[#Int8Ty]] 2
52+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec4Ty:]] [[#Int8Ty]] 4
53+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec16Ty:]] [[#Int8Ty]] 16
3854
; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1
55+
; CHECK-SPIRV-DAG: Constant [[#Int16Ty]] [[#Int16Const:]] 1
56+
; CHECK-SPIRV-DAG: Constant [[#Int64Ty]] [[#Int64Const:]] 1
57+
; CHECK-SPIRV-DAG: ConstantComposite [[#Int8Vec2Ty]] [[#Int8Vec2Const:]] [[#Int8Const]] [[#Int8Const]]
3958

4059
; CHECK-SPIRV-DAG: TypeInt [[#Int4Ty:]] 4 0
4160
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec8Ty:]] [[#Int4Ty]] 8
4261
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec2Ty:]] [[#Int4Ty]] 2
62+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec4Ty:]] [[#Int4Ty]] 4
63+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec16Ty:]] [[#Int4Ty]] 16
4364

4465
; CHECK-SPIRV-DAG: TypeFloat [[#Float8E4M3Ty:]] 8 4214
4566
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec8Ty:]] [[#Float8E4M3Ty]] 8
4667
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec2Ty:]] [[#Float8E4M3Ty]] 2
68+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec4Ty:]] [[#Float8E4M3Ty]] 4
69+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec16Ty:]] [[#Float8E4M3Ty]] 16
4770

4871
; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}}
4972
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec8Ty:]] [[#HFloat16Ty]] 8
5073
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec2Ty:]] [[#HFloat16Ty]] 2
74+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec4Ty:]] [[#HFloat16Ty]] 4
75+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec16Ty:]] [[#HFloat16Ty]] 16
5176
; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HFloat16Const:]] 15360
5277
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec8Ty]] [[#HFloat16Vec8Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5378
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec2Ty]] [[#HFloat16Vec2Const:]] [[#HFloat16Const]] [[#HFloat16Const]]
79+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec4Ty]] [[#HFloat16Vec4Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
80+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec16Ty]] [[#HFloat16Vec16Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5481

5582
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
5683
target triple = "spir-unknown-unknown"
@@ -136,3 +163,126 @@ entry:
136163
}
137164

138165
declare dso_local spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELc(<2 x half>)
166+
167+
; Packed in 16-bit integer
168+
169+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_16]] [[#]]
170+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int16Const]]
171+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
172+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
173+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
174+
175+
; CHECK-LLVM-LABEL: int4_e4m3_16
176+
; CHECK-LLVM: %[[#Cast:]] = bitcast i16 1 to <4 x i4>
177+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
178+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
179+
180+
define spir_func <4 x i8> @int4_e4m3_16() {
181+
entry:
182+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16 1)
183+
ret <4 x i8> %0
184+
}
185+
186+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16)
187+
188+
; Packed in 64-bit integer
189+
190+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_64]] [[#]]
191+
; CHECK-SPIRV: Bitcast [[#Int4Vec16Ty]] [[#Cast1:]] [[#Int64Const]]
192+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec16Ty]] [[#Conv:]] [[#Cast1]]
193+
; CHECK-SPIRV: Bitcast [[#Int8Vec16Ty]] [[#Cast2:]] [[#Conv]]
194+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
195+
196+
; CHECK-LLVM-LABEL: int4_e4m3_64
197+
; CHECK-LLVM: %[[#Cast:]] = bitcast i64 1 to <16 x i4>
198+
; CHECK-LLVM: %[[#Call:]] = call <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv16_i(<16 x i4> %[[#Cast]])
199+
; CHECK-LLVM: ret <16 x i8> %[[#Call]]
200+
201+
define spir_func <16 x i8> @int4_e4m3_64() {
202+
entry:
203+
%0 = call spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64 1)
204+
ret <16 x i8> %0
205+
}
206+
207+
declare dso_local spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64)
208+
209+
; Packed in vector of 8-bit integers
210+
211+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_vec2xi8]] [[#]]
212+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int8Vec2Const]]
213+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
214+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
215+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
216+
217+
; CHECK-LLVM-LABEL: int4_e4m3_vec2xi8
218+
; CHECK-LLVM: %[[#Cast:]] = bitcast <2 x i8> <i8 1, i8 1> to <4 x i4>
219+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
220+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
221+
222+
define spir_func <4 x i8> @int4_e4m3_vec2xi8() {
223+
entry:
224+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8> <i8 1, i8 1>)
225+
ret <4 x i8> %0
226+
}
227+
228+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8>)
229+
230+
; To packed in 16-bit integer
231+
232+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_16]] [[#]]
233+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
234+
; CHECK-SPIRV: Bitcast [[#Int16Ty]] [[#Cast2:]] [[#Conv]]
235+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
236+
237+
; CHECK-LLVM-LABEL: hf16_int4_16
238+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
239+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to i16
240+
; CHECK-LLVM: ret i16 %[[#Cast]]
241+
242+
define spir_func i16 @hf16_int4_16() {
243+
entry:
244+
%0 = call i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
245+
ret i16 %0
246+
}
247+
248+
declare dso_local spir_func i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half>)
249+
250+
; To packed in 64-bit integer
251+
252+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_64]] [[#]]
253+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec16Ty]] [[#Conv:]] [[#HFloat16Vec16Const]]
254+
; CHECK-SPIRV: Bitcast [[#Int64Ty]] [[#Cast2:]] [[#Conv]]
255+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
256+
257+
; CHECK-LLVM-LABEL: hf16_int4_64
258+
; CHECK-LLVM: %[[#Call:]] = call <16 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
259+
; CHECK-LLVM: %[[#Cast:]] = bitcast <16 x i4> %[[#Call]] to i64
260+
; CHECK-LLVM: ret i64 %[[#Cast]]
261+
262+
define spir_func i64 @hf16_int4_64() {
263+
entry:
264+
%0 = call i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0>)
265+
ret i64 %0
266+
}
267+
268+
declare dso_local spir_func i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half>)
269+
270+
; To packed in vector of 8-bit integers
271+
272+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_vec2xi8]] [[#]]
273+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
274+
; CHECK-SPIRV: Bitcast [[#Int8Vec2Ty]] [[#Cast:]] [[#Conv]]
275+
; CHECK-SPIRV: ReturnValue [[#Cast]]
276+
277+
; CHECK-LLVM-LABEL: hf16_int4_vec2xi8
278+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
279+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to <2 x i8>
280+
; CHECK-LLVM: ret <2 x i8> %[[#Cast]]
281+
282+
define spir_func <2 x i8> @hf16_int4_vec2xi8() {
283+
entry:
284+
%0 = call <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
285+
ret <2 x i8> %0
286+
}
287+
288+
declare dso_local spir_func <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half>)

0 commit comments

Comments
 (0)