Skip to content

Commit e45bd3a

Browse files
authored
[Backport to 19] Extend INT4/FP4 packed conversions for i16, i64, and vector packed inputs (#3675) (#3710)
1 parent 2256097 commit e45bd3a

3 files changed

Lines changed: 317 additions & 19 deletions

File tree

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5317,21 +5317,18 @@ processMiniFPOrInt4Type(Type *LLVMTy, FPEncodingWrap Encoding,
53175317
unsigned TyWidth = cast<IntegerType>(ScalarTy)->getBitWidth();
53185318
unsigned VecSize = 0;
53195319

5320-
if (TyWidth == 32) {
5321-
// Int4 or FP4 packed in 32-bit integer, change type and vector size.
5322-
assert((Encoding == FPEncodingWrap::E2M1 ||
5323-
Encoding == FPEncodingWrap::Integer) &&
5324-
"Unknown FP encoding");
5320+
bool IsPacked =
5321+
Encoding == FPEncodingWrap::E2M1 || Encoding == FPEncodingWrap::Integer;
5322+
if (IsPacked &&
5323+
(TyWidth == 8 || TyWidth == 16 || TyWidth == 32 || TyWidth == 64)) {
5324+
// Int4 or FP4 packed in an integer: each N-bit integer holds N/4 values.
53255325
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
53265326
"FP4 and Int4 matrices must not be packed");
5327-
VecSize = 8;
5328-
TyWidth = 4;
5329-
} else if (TyWidth == 8 && (Encoding == FPEncodingWrap::E2M1 ||
5330-
Encoding == FPEncodingWrap::Integer)) {
5331-
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
5332-
"FP4 and Int4 matrices must not be packed");
5333-
// Int4 or FP4 packed in 8-bit integer, change type and vector size.
5334-
VecSize = 2;
5327+
unsigned OuterVecLen =
5328+
LLVMTy->isVectorTy()
5329+
? cast<VectorType>(LLVMTy)->getElementCount().getFixedValue()
5330+
: 1;
5331+
VecSize = (TyWidth / 4) * OuterVecLen;
53355332
TyWidth = 4;
53365333
} else {
53375334
if (LLVMTy->isVectorTy())
@@ -5427,14 +5424,16 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
54275424
->getArgs());
54285425
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
54295426
} else if (FPDesc.SrcEncoding != FPEncodingWrap::Integer ||
5430-
(SrcTy->isTypeVector() && !LLVMSrcTy->isVectorTy())) {
5431-
// Create bitcast for FP4, FP8 and packed Int4.
5427+
SrcVecSize > 0) {
5428+
// Create bitcast for FP4, FP8 and packed Int4 (including cases where
5429+
// both the LLVM and SPIR-V types are vectors but with different
5430+
// sizes, e.g. <2 x i8> repacked as <4 x Int4>).
54325431
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
54335432
}
54345433
}
5434+
unsigned DstVecSize = 0;
54355435
if (!DstTy) {
54365436
// Dst type is 'mini' float or int4.
5437-
unsigned DstVecSize = 0;
54385437
DstTy = processMiniFPOrInt4Type(LLVMDstTy, FPDesc.DstEncoding,
54395438
GetScalarTy, BM, DstVecSize);
54405439

@@ -5462,10 +5461,9 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
54625461
if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 ||
54635462
FPDesc.DstEncoding == FPEncodingWrap::BF16)
54645463
return Conv;
5465-
// Originally not-packed integer.
5464+
// Originally not-packed integer (no repacking, or cooperative matrix).
54665465
if (FPDesc.DstEncoding == FPEncodingWrap::Integer &&
5467-
(DstTy->isTypeVector() == LLVMDstTy->isVectorTy() ||
5468-
isLLVMCooperativeMatrixType(LLVMDstTy)))
5466+
(DstVecSize == 0 || isLLVMCooperativeMatrixType(LLVMDstTy)))
54695467
return Conv;
54705468
// Need to adjust types: create bitcast for FP8 and packed Int4.
54715469
SPIRVValue *BitCast =

test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
; 1. from packed Int4 to ... :
66
; a. packed in 32-bit
77
; b. packed in 8-bit
8+
; c. packed in 16-bit
9+
; d. packed in 64-bit
10+
; e. packed in vector of 8-bit integers
811
; 2. to packed Int4 from ... :
912
; a. packed in 32-bit
1013
; b. packed in 8-bit
14+
; c. packed in 16-bit
15+
; d. packed in 64-bit
16+
; e. packed in vector of 8-bit integers
1117

1218
; RUN: llvm-as %s -o %t.bc
1319
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_INTEL_int4,+SPV_KHR_bfloat16
@@ -28,29 +34,50 @@
2834
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_8:]] "int4_e4m3_8"
2935
; CHECK-SPIRV-DAG: Name [[#hf16_int4_32:]] "hf16_int4_32"
3036
; CHECK-SPIRV-DAG: Name [[#hf16_int4_8:]] "hf16_int4_8"
37+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_16:]] "int4_e4m3_16"
38+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_64:]] "int4_e4m3_64"
39+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_vec2xi8:]] "int4_e4m3_vec2xi8"
40+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_16:]] "hf16_int4_16"
41+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_64:]] "hf16_int4_64"
42+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_vec2xi8:]] "hf16_int4_vec2xi8"
3143

3244
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
3345
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Int32Const:]] 1
3446

3547
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
48+
; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0
49+
; CHECK-SPIRV-DAG: TypeInt [[#Int64Ty:]] 64 0
3650
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec8Ty:]] [[#Int8Ty]] 8
3751
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec2Ty:]] [[#Int8Ty]] 2
52+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec4Ty:]] [[#Int8Ty]] 4
53+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec16Ty:]] [[#Int8Ty]] 16
3854
; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1
55+
; CHECK-SPIRV-DAG: Constant [[#Int16Ty]] [[#Int16Const:]] 1
56+
; CHECK-SPIRV-DAG: Constant [[#Int64Ty]] [[#Int64Const:]] 1
57+
; CHECK-SPIRV-DAG: ConstantComposite [[#Int8Vec2Ty]] [[#Int8Vec2Const:]] [[#Int8Const]] [[#Int8Const]]
3958

4059
; CHECK-SPIRV-DAG: TypeInt [[#Int4Ty:]] 4 0
4160
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec8Ty:]] [[#Int4Ty]] 8
4261
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec2Ty:]] [[#Int4Ty]] 2
62+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec4Ty:]] [[#Int4Ty]] 4
63+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec16Ty:]] [[#Int4Ty]] 16
4364

4465
; CHECK-SPIRV-DAG: TypeFloat [[#Float8E4M3Ty:]] 8 4214
4566
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec8Ty:]] [[#Float8E4M3Ty]] 8
4667
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec2Ty:]] [[#Float8E4M3Ty]] 2
68+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec4Ty:]] [[#Float8E4M3Ty]] 4
69+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec16Ty:]] [[#Float8E4M3Ty]] 16
4770

4871
; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}}
4972
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec8Ty:]] [[#HFloat16Ty]] 8
5073
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec2Ty:]] [[#HFloat16Ty]] 2
74+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec4Ty:]] [[#HFloat16Ty]] 4
75+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec16Ty:]] [[#HFloat16Ty]] 16
5176
; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HFloat16Const:]] 15360
5277
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec8Ty]] [[#HFloat16Vec8Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5378
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec2Ty]] [[#HFloat16Vec2Const:]] [[#HFloat16Const]] [[#HFloat16Const]]
79+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec4Ty]] [[#HFloat16Vec4Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
80+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec16Ty]] [[#HFloat16Vec16Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5481

5582
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
5683
target triple = "spir-unknown-unknown"
@@ -136,3 +163,126 @@ entry:
136163
}
137164

138165
declare dso_local spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELc(<2 x half>)
166+
167+
; Packed in 16-bit integer
168+
169+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_16]] [[#]]
170+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int16Const]]
171+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
172+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
173+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
174+
175+
; CHECK-LLVM-LABEL: int4_e4m3_16
176+
; CHECK-LLVM: %[[#Cast:]] = bitcast i16 1 to <4 x i4>
177+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
178+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
179+
180+
define spir_func <4 x i8> @int4_e4m3_16() {
181+
entry:
182+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16 1)
183+
ret <4 x i8> %0
184+
}
185+
186+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16)
187+
188+
; Packed in 64-bit integer
189+
190+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_64]] [[#]]
191+
; CHECK-SPIRV: Bitcast [[#Int4Vec16Ty]] [[#Cast1:]] [[#Int64Const]]
192+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec16Ty]] [[#Conv:]] [[#Cast1]]
193+
; CHECK-SPIRV: Bitcast [[#Int8Vec16Ty]] [[#Cast2:]] [[#Conv]]
194+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
195+
196+
; CHECK-LLVM-LABEL: int4_e4m3_64
197+
; CHECK-LLVM: %[[#Cast:]] = bitcast i64 1 to <16 x i4>
198+
; CHECK-LLVM: %[[#Call:]] = call <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv16_i(<16 x i4> %[[#Cast]])
199+
; CHECK-LLVM: ret <16 x i8> %[[#Call]]
200+
201+
define spir_func <16 x i8> @int4_e4m3_64() {
202+
entry:
203+
%0 = call spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64 1)
204+
ret <16 x i8> %0
205+
}
206+
207+
declare dso_local spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64)
208+
209+
; Packed in vector of 8-bit integers
210+
211+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_vec2xi8]] [[#]]
212+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int8Vec2Const]]
213+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
214+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
215+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
216+
217+
; CHECK-LLVM-LABEL: int4_e4m3_vec2xi8
218+
; CHECK-LLVM: %[[#Cast:]] = bitcast <2 x i8> <i8 1, i8 1> to <4 x i4>
219+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
220+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
221+
222+
define spir_func <4 x i8> @int4_e4m3_vec2xi8() {
223+
entry:
224+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8> <i8 1, i8 1>)
225+
ret <4 x i8> %0
226+
}
227+
228+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8>)
229+
230+
; To packed in 16-bit integer
231+
232+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_16]] [[#]]
233+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
234+
; CHECK-SPIRV: Bitcast [[#Int16Ty]] [[#Cast2:]] [[#Conv]]
235+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
236+
237+
; CHECK-LLVM-LABEL: hf16_int4_16
238+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
239+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to i16
240+
; CHECK-LLVM: ret i16 %[[#Cast]]
241+
242+
define spir_func i16 @hf16_int4_16() {
243+
entry:
244+
%0 = call i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
245+
ret i16 %0
246+
}
247+
248+
declare dso_local spir_func i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half>)
249+
250+
; To packed in 64-bit integer
251+
252+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_64]] [[#]]
253+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec16Ty]] [[#Conv:]] [[#HFloat16Vec16Const]]
254+
; CHECK-SPIRV: Bitcast [[#Int64Ty]] [[#Cast2:]] [[#Conv]]
255+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
256+
257+
; CHECK-LLVM-LABEL: hf16_int4_64
258+
; CHECK-LLVM: %[[#Call:]] = call <16 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
259+
; CHECK-LLVM: %[[#Cast:]] = bitcast <16 x i4> %[[#Call]] to i64
260+
; CHECK-LLVM: ret i64 %[[#Cast]]
261+
262+
define spir_func i64 @hf16_int4_64() {
263+
entry:
264+
%0 = call i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0>)
265+
ret i64 %0
266+
}
267+
268+
declare dso_local spir_func i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half>)
269+
270+
; To packed in vector of 8-bit integers
271+
272+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_vec2xi8]] [[#]]
273+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
274+
; CHECK-SPIRV: Bitcast [[#Int8Vec2Ty]] [[#Cast:]] [[#Conv]]
275+
; CHECK-SPIRV: ReturnValue [[#Cast]]
276+
277+
; CHECK-LLVM-LABEL: hf16_int4_vec2xi8
278+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
279+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to <2 x i8>
280+
; CHECK-LLVM: ret <2 x i8> %[[#Cast]]
281+
282+
define spir_func <2 x i8> @hf16_int4_vec2xi8() {
283+
entry:
284+
%0 = call <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
285+
ret <2 x i8> %0
286+
}
287+
288+
declare dso_local spir_func <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half>)

0 commit comments

Comments
 (0)