Skip to content

Commit 8c22642

Browse files
authored
[Backport to 18] Extend INT4/FP4 packed conversions for i16, i64, and vector packed inputs (#3675) (#3711)
1 parent bff9650 commit 8c22642

3 files changed

Lines changed: 317 additions & 19 deletions

File tree

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -5233,21 +5233,18 @@ processMiniFPOrInt4Type(Type *LLVMTy, FPEncodingWrap Encoding,
52335233
unsigned TyWidth = cast<IntegerType>(ScalarTy)->getBitWidth();
52345234
unsigned VecSize = 0;
52355235

5236-
if (TyWidth == 32) {
5237-
// Int4 or FP4 packed in 32-bit integer, change type and vector size.
5238-
assert((Encoding == FPEncodingWrap::E2M1 ||
5239-
Encoding == FPEncodingWrap::Integer) &&
5240-
"Unknown FP encoding");
5236+
bool IsPacked =
5237+
Encoding == FPEncodingWrap::E2M1 || Encoding == FPEncodingWrap::Integer;
5238+
if (IsPacked &&
5239+
(TyWidth == 8 || TyWidth == 16 || TyWidth == 32 || TyWidth == 64)) {
5240+
// Int4 or FP4 packed in an integer: each N-bit integer holds N/4 values.
52415241
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
52425242
"FP4 and Int4 matrices must not be packed");
5243-
VecSize = 8;
5244-
TyWidth = 4;
5245-
} else if (TyWidth == 8 && (Encoding == FPEncodingWrap::E2M1 ||
5246-
Encoding == FPEncodingWrap::Integer)) {
5247-
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
5248-
"FP4 and Int4 matrices must not be packed");
5249-
// Int4 or FP4 packed in 8-bit integer, change type and vector size.
5250-
VecSize = 2;
5243+
unsigned OuterVecLen =
5244+
LLVMTy->isVectorTy()
5245+
? cast<VectorType>(LLVMTy)->getElementCount().getFixedValue()
5246+
: 1;
5247+
VecSize = (TyWidth / 4) * OuterVecLen;
52515248
TyWidth = 4;
52525249
} else {
52535250
if (LLVMTy->isVectorTy())
@@ -5343,14 +5340,16 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
53435340
->getArgs());
53445341
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
53455342
} else if (FPDesc.SrcEncoding != FPEncodingWrap::Integer ||
5346-
(SrcTy->isTypeVector() && !LLVMSrcTy->isVectorTy())) {
5347-
// Create bitcast for FP4, FP8 and packed Int4.
5343+
SrcVecSize > 0) {
5344+
// Create bitcast for FP4, FP8 and packed Int4 (including cases where
5345+
// both the LLVM and SPIR-V types are vectors but with different
5346+
// sizes, e.g. <2 x i8> repacked as <4 x Int4>).
53485347
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
53495348
}
53505349
}
5350+
unsigned DstVecSize = 0;
53515351
if (!DstTy) {
53525352
// Dst type is 'mini' float or int4.
5353-
unsigned DstVecSize = 0;
53545353
DstTy = processMiniFPOrInt4Type(LLVMDstTy, FPDesc.DstEncoding,
53555354
GetScalarTy, BM, DstVecSize);
53565355

@@ -5378,10 +5377,9 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
53785377
if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 ||
53795378
FPDesc.DstEncoding == FPEncodingWrap::BF16)
53805379
return Conv;
5381-
// Originally not-packed integer.
5380+
// Originally not-packed integer (no repacking, or cooperative matrix).
53825381
if (FPDesc.DstEncoding == FPEncodingWrap::Integer &&
5383-
(DstTy->isTypeVector() == LLVMDstTy->isVectorTy() ||
5384-
isLLVMCooperativeMatrixType(LLVMDstTy)))
5382+
(DstVecSize == 0 || isLLVMCooperativeMatrixType(LLVMDstTy)))
53855383
return Conv;
53865384
// Need to adjust types: create bitcast for FP8 and packed Int4.
53875385
SPIRVValue *BitCast =

test/extensions/INTEL/SPV_INTEL_int4/conversions_packed.ll

Lines changed: 150 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,15 @@
55
; 1. from packed Int4 to ... :
66
; a. packed in 32-bit
77
; b. packed in 8-bit
8+
; c. packed in 16-bit
9+
; d. packed in 64-bit
10+
; e. packed in vector of 8-bit integers
811
; 2. to packed Int4 from ... :
912
; a. packed in 32-bit
1013
; b. packed in 8-bit
14+
; c. packed in 16-bit
15+
; d. packed in 64-bit
16+
; e. packed in vector of 8-bit integers
1117

1218
; RUN: llvm-as %s -o %t.bc
1319
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_INTEL_int4,+SPV_KHR_bfloat16
@@ -28,29 +34,50 @@
2834
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_8:]] "int4_e4m3_8"
2935
; CHECK-SPIRV-DAG: Name [[#hf16_int4_32:]] "hf16_int4_32"
3036
; CHECK-SPIRV-DAG: Name [[#hf16_int4_8:]] "hf16_int4_8"
37+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_16:]] "int4_e4m3_16"
38+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_64:]] "int4_e4m3_64"
39+
; CHECK-SPIRV-DAG: Name [[#int4_e4m3_vec2xi8:]] "int4_e4m3_vec2xi8"
40+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_16:]] "hf16_int4_16"
41+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_64:]] "hf16_int4_64"
42+
; CHECK-SPIRV-DAG: Name [[#hf16_int4_vec2xi8:]] "hf16_int4_vec2xi8"
3143

3244
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
3345
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Int32Const:]] 1
3446

3547
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
48+
; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0
49+
; CHECK-SPIRV-DAG: TypeInt [[#Int64Ty:]] 64 0
3650
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec8Ty:]] [[#Int8Ty]] 8
3751
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec2Ty:]] [[#Int8Ty]] 2
52+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec4Ty:]] [[#Int8Ty]] 4
53+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec16Ty:]] [[#Int8Ty]] 16
3854
; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1
55+
; CHECK-SPIRV-DAG: Constant [[#Int16Ty]] [[#Int16Const:]] 1
56+
; CHECK-SPIRV-DAG: Constant [[#Int64Ty]] [[#Int64Const:]] 1
57+
; CHECK-SPIRV-DAG: ConstantComposite [[#Int8Vec2Ty]] [[#Int8Vec2Const:]] [[#Int8Const]] [[#Int8Const]]
3958

4059
; CHECK-SPIRV-DAG: TypeInt [[#Int4Ty:]] 4 0
4160
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec8Ty:]] [[#Int4Ty]] 8
4261
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec2Ty:]] [[#Int4Ty]] 2
62+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec4Ty:]] [[#Int4Ty]] 4
63+
; CHECK-SPIRV-DAG: TypeVector [[#Int4Vec16Ty:]] [[#Int4Ty]] 16
4364

4465
; CHECK-SPIRV-DAG: TypeFloat [[#Float8E4M3Ty:]] 8 4214
4566
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec8Ty:]] [[#Float8E4M3Ty]] 8
4667
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec2Ty:]] [[#Float8E4M3Ty]] 2
68+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec4Ty:]] [[#Float8E4M3Ty]] 4
69+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec16Ty:]] [[#Float8E4M3Ty]] 16
4770

4871
; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}}
4972
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec8Ty:]] [[#HFloat16Ty]] 8
5073
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec2Ty:]] [[#HFloat16Ty]] 2
74+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec4Ty:]] [[#HFloat16Ty]] 4
75+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec16Ty:]] [[#HFloat16Ty]] 16
5176
; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HFloat16Const:]] 15360
5277
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec8Ty]] [[#HFloat16Vec8Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5378
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec2Ty]] [[#HFloat16Vec2Const:]] [[#HFloat16Const]] [[#HFloat16Const]]
79+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec4Ty]] [[#HFloat16Vec4Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
80+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec16Ty]] [[#HFloat16Vec16Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5481

5582
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
5683
target triple = "spir-unknown-unknown"
@@ -136,3 +163,126 @@ entry:
136163
}
137164

138165
declare dso_local spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELc(<2 x half>)
166+
167+
; Packed in 16-bit integer
168+
169+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_16]] [[#]]
170+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int16Const]]
171+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
172+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
173+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
174+
175+
; CHECK-LLVM-LABEL: int4_e4m3_16
176+
; CHECK-LLVM: %[[#Cast:]] = bitcast i16 1 to <4 x i4>
177+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
178+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
179+
180+
define spir_func <4 x i8> @int4_e4m3_16() {
181+
entry:
182+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16 1)
183+
ret <4 x i8> %0
184+
}
185+
186+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELs(i16)
187+
188+
; Packed in 64-bit integer
189+
190+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_64]] [[#]]
191+
; CHECK-SPIRV: Bitcast [[#Int4Vec16Ty]] [[#Cast1:]] [[#Int64Const]]
192+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec16Ty]] [[#Conv:]] [[#Cast1]]
193+
; CHECK-SPIRV: Bitcast [[#Int8Vec16Ty]] [[#Cast2:]] [[#Conv]]
194+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
195+
196+
; CHECK-LLVM-LABEL: int4_e4m3_64
197+
; CHECK-LLVM: %[[#Cast:]] = bitcast i64 1 to <16 x i4>
198+
; CHECK-LLVM: %[[#Call:]] = call <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv16_i(<16 x i4> %[[#Cast]])
199+
; CHECK-LLVM: ret <16 x i8> %[[#Call]]
200+
201+
define spir_func <16 x i8> @int4_e4m3_64() {
202+
entry:
203+
%0 = call spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64 1)
204+
ret <16 x i8> %0
205+
}
206+
207+
declare dso_local spir_func <16 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELl(i64)
208+
209+
; Packed in vector of 8-bit integers
210+
211+
; CHECK-SPIRV: Function [[#]] [[#int4_e4m3_vec2xi8]] [[#]]
212+
; CHECK-SPIRV: Bitcast [[#Int4Vec4Ty]] [[#Cast1:]] [[#Int8Vec2Const]]
213+
; CHECK-SPIRV: ConvertSToF [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
214+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
215+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
216+
217+
; CHECK-LLVM-LABEL: int4_e4m3_vec2xi8
218+
; CHECK-LLVM: %[[#Cast:]] = bitcast <2 x i8> <i8 1, i8 1> to <4 x i4>
219+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
220+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
221+
222+
define spir_func <4 x i8> @int4_e4m3_vec2xi8() {
223+
entry:
224+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8> <i8 1, i8 1>)
225+
ret <4 x i8> %0
226+
}
227+
228+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertInt4ToE4M3INTELDv4_i(<2 x i8>)
229+
230+
; To packed in 16-bit integer
231+
232+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_16]] [[#]]
233+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
234+
; CHECK-SPIRV: Bitcast [[#Int16Ty]] [[#Cast2:]] [[#Conv]]
235+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
236+
237+
; CHECK-LLVM-LABEL: hf16_int4_16
238+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
239+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to i16
240+
; CHECK-LLVM: ret i16 %[[#Cast]]
241+
242+
define spir_func i16 @hf16_int4_16() {
243+
entry:
244+
%0 = call i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
245+
ret i16 %0
246+
}
247+
248+
declare dso_local spir_func i16 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half>)
249+
250+
; To packed in 64-bit integer
251+
252+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_64]] [[#]]
253+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec16Ty]] [[#Conv:]] [[#HFloat16Vec16Const]]
254+
; CHECK-SPIRV: Bitcast [[#Int64Ty]] [[#Cast2:]] [[#Conv]]
255+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
256+
257+
; CHECK-LLVM-LABEL: hf16_int4_64
258+
; CHECK-LLVM: %[[#Call:]] = call <16 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
259+
; CHECK-LLVM: %[[#Cast:]] = bitcast <16 x i4> %[[#Call]] to i64
260+
; CHECK-LLVM: ret i64 %[[#Cast]]
261+
262+
define spir_func i64 @hf16_int4_64() {
263+
entry:
264+
%0 = call i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half> <half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0>)
265+
ret i64 %0
266+
}
267+
268+
declare dso_local spir_func i64 @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv16_Dh(<16 x half>)
269+
270+
; To packed in vector of 8-bit integers
271+
272+
; CHECK-SPIRV: Function [[#]] [[#hf16_int4_vec2xi8]] [[#]]
273+
; CHECK-SPIRV: ConvertFToS [[#Int4Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
274+
; CHECK-SPIRV: Bitcast [[#Int8Vec2Ty]] [[#Cast:]] [[#Conv]]
275+
; CHECK-SPIRV: ReturnValue [[#Cast]]
276+
277+
; CHECK-LLVM-LABEL: hf16_int4_vec2xi8
278+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELDv4_Dh(<4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>)
279+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to <2 x i8>
280+
; CHECK-LLVM: ret <2 x i8> %[[#Cast]]
281+
282+
define spir_func <2 x i8> @hf16_int4_vec2xi8() {
283+
entry:
284+
%0 = call <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
285+
ret <2 x i8> %0
286+
}
287+
288+
declare dso_local spir_func <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToInt4INTELKDv4_Dh(<4 x half>)

0 commit comments

Comments
 (0)