Skip to content

Commit f676199

Browse files
authored
[Backport to 22] Extend INT4/FP4 packed conversions for i16, i64, and vector packed inputs (#3675) (#3706)
1 parent 8ad4c2e commit f676199

3 files changed

Lines changed: 317 additions & 19 deletions

File tree

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5559,21 +5559,18 @@ processMiniFPOrInt4Type(Type *LLVMTy, FPEncodingWrap Encoding,
55595559
unsigned TyWidth = cast<IntegerType>(ScalarTy)->getBitWidth();
55605560
unsigned VecSize = 0;
55615561

5562-
if (TyWidth == 32) {
5563-
// Int4 or FP4 packed in 32-bit integer, change type and vector size.
5564-
assert((Encoding == FPEncodingWrap::E2M1 ||
5565-
Encoding == FPEncodingWrap::Integer) &&
5566-
"Unknown FP encoding");
5562+
bool IsPacked =
5563+
Encoding == FPEncodingWrap::E2M1 || Encoding == FPEncodingWrap::Integer;
5564+
if (IsPacked &&
5565+
(TyWidth == 8 || TyWidth == 16 || TyWidth == 32 || TyWidth == 64)) {
5566+
// Int4 or FP4 packed in an integer: each N-bit integer holds N/4 values.
55675567
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
55685568
"FP4 and Int4 matrices must not be packed");
5569-
VecSize = 8;
5570-
TyWidth = 4;
5571-
} else if (TyWidth == 8 && (Encoding == FPEncodingWrap::E2M1 ||
5572-
Encoding == FPEncodingWrap::Integer)) {
5573-
assert(!isLLVMCooperativeMatrixType(LLVMTy) &&
5574-
"FP4 and Int4 matrices must not be packed");
5575-
// Int4 or FP4 packed in 8-bit integer, change type and vector size.
5576-
VecSize = 2;
5569+
unsigned OuterVecLen =
5570+
LLVMTy->isVectorTy()
5571+
? cast<VectorType>(LLVMTy)->getElementCount().getFixedValue()
5572+
: 1;
5573+
VecSize = (TyWidth / 4) * OuterVecLen;
55775574
TyWidth = 4;
55785575
} else {
55795576
if (LLVMTy->isVectorTy())
@@ -5669,14 +5666,16 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
56695666
->getArgs());
56705667
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
56715668
} else if (FPDesc.SrcEncoding != FPEncodingWrap::Integer ||
5672-
(SrcTy->isTypeVector() && !LLVMSrcTy->isVectorTy())) {
5673-
// Create bitcast for FP4, FP8 and packed Int4.
5669+
SrcVecSize > 0) {
5670+
// Create bitcast for FP4, FP8 and packed Int4 (including cases where
5671+
// both the LLVM and SPIR-V types are vectors but with different
5672+
// sizes, e.g. <2 x i8> repacked as <4 x Int4>).
56745673
SrcOp = BM->addUnaryInst(OpBitcast, SrcTy, SrcOp, BB);
56755674
}
56765675
}
5676+
unsigned DstVecSize = 0;
56775677
if (!DstTy) {
56785678
// Dst type is 'mini' float or int4.
5679-
unsigned DstVecSize = 0;
56805679
DstTy = processMiniFPOrInt4Type(LLVMDstTy, FPDesc.DstEncoding,
56815680
GetScalarTy, BM, DstVecSize);
56825681

@@ -5704,10 +5703,9 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI,
57045703
if (FPDesc.DstEncoding == FPEncodingWrap::IEEE754 ||
57055704
FPDesc.DstEncoding == FPEncodingWrap::BF16)
57065705
return Conv;
5707-
// Originally not-packed integer.
5706+
// Originally not-packed integer (no repacking, or cooperative matrix).
57085707
if (FPDesc.DstEncoding == FPEncodingWrap::Integer &&
5709-
(DstTy->isTypeVector() == LLVMDstTy->isVectorTy() ||
5710-
isLLVMCooperativeMatrixType(LLVMDstTy)))
5708+
(DstVecSize == 0 || isLLVMCooperativeMatrixType(LLVMDstTy)))
57115709
return Conv;
57125710
// Need to adjust types: create bitcast for FP8 and packed Int4.
57135711
SPIRVValue *BitCast =

test/extensions/INTEL/SPV_INTEL_float4/conversions_packed.ll

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
; 1. from packed FP4 to ... :
66
; a. packed in 32-bit
77
; b. packed in 8-bit
8+
; c. packed in 16-bit
9+
; d. packed in 64-bit
10+
; e. packed in vector of 8-bit integers
811
; 2. to packed FP4 from ... :
912
; a. packed in 32-bit
1013
; b. packed in 8-bit
14+
; c. packed in 16-bit
15+
; d. packed in 64-bit
16+
; e. packed in vector of 8-bit integers
1117

1218
; RUN: llvm-spirv %s -o %t.spv --spirv-ext=+SPV_EXT_float8,+SPV_INTEL_float4,+SPV_INTEL_int4,+SPV_KHR_bfloat16
1319
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
@@ -25,29 +31,50 @@
2531
; CHECK-SPIRV-DAG: Name [[#fp4e2m1_hf8_8:]] "fp4e2m1_hf8_8"
2632
; CHECK-SPIRV-DAG: Name [[#hf16_fp4e2m1_32:]] "hf16_fp4e2m1_32"
2733
; CHECK-SPIRV-DAG: Name [[#hf16_fp4e2m1_8:]] "hf16_fp4e2m1_8"
34+
; CHECK-SPIRV-DAG: Name [[#fp4e2m1_hf8_16:]] "fp4e2m1_hf8_16"
35+
; CHECK-SPIRV-DAG: Name [[#fp4e2m1_hf8_64:]] "fp4e2m1_hf8_64"
36+
; CHECK-SPIRV-DAG: Name [[#fp4e2m1_hf8_vec2xi8:]] "fp4e2m1_hf8_vec2xi8"
37+
; CHECK-SPIRV-DAG: Name [[#hf16_fp4e2m1_16:]] "hf16_fp4e2m1_16"
38+
; CHECK-SPIRV-DAG: Name [[#hf16_fp4e2m1_64:]] "hf16_fp4e2m1_64"
39+
; CHECK-SPIRV-DAG: Name [[#hf16_fp4e2m1_vec2xi8:]] "hf16_fp4e2m1_vec2xi8"
2840

2941
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
3042
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Int32Const:]] 1
3143

3244
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
45+
; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0
46+
; CHECK-SPIRV-DAG: TypeInt [[#Int64Ty:]] 64 0
3347
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec8Ty:]] [[#Int8Ty]] 8
3448
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec2Ty:]] [[#Int8Ty]] 2
49+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec4Ty:]] [[#Int8Ty]] 4
50+
; CHECK-SPIRV-DAG: TypeVector [[#Int8Vec16Ty:]] [[#Int8Ty]] 16
3551
; CHECK-SPIRV-DAG: Constant [[#Int8Ty]] [[#Int8Const:]] 1
52+
; CHECK-SPIRV-DAG: Constant [[#Int16Ty]] [[#Int16Const:]] 1
53+
; CHECK-SPIRV-DAG: Constant [[#Int64Ty]] [[#Int64Const:]] 1
54+
; CHECK-SPIRV-DAG: ConstantComposite [[#Int8Vec2Ty]] [[#Int8Vec2Const:]] [[#Int8Const]] [[#Int8Const]]
3655

3756
; CHECK-SPIRV-DAG: TypeFloat [[#E2M1Ty:]] 4 6214
3857
; CHECK-SPIRV-DAG: TypeVector [[#E2M1Vec8Ty:]] [[#E2M1Ty]] 8
3958
; CHECK-SPIRV-DAG: TypeVector [[#E2M1Vec2Ty:]] [[#E2M1Ty]] 2
59+
; CHECK-SPIRV-DAG: TypeVector [[#E2M1Vec4Ty:]] [[#E2M1Ty]] 4
60+
; CHECK-SPIRV-DAG: TypeVector [[#E2M1Vec16Ty:]] [[#E2M1Ty]] 16
4061

4162
; CHECK-SPIRV-DAG: TypeFloat [[#Float8E4M3Ty:]] 8 4214
4263
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec8Ty:]] [[#Float8E4M3Ty]] 8
4364
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec2Ty:]] [[#Float8E4M3Ty]] 2
65+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec4Ty:]] [[#Float8E4M3Ty]] 4
66+
; CHECK-SPIRV-DAG: TypeVector [[#Float8E4M3Vec16Ty:]] [[#Float8E4M3Ty]] 16
4467

4568
; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}}
4669
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec8Ty:]] [[#HFloat16Ty]] 8
4770
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec2Ty:]] [[#HFloat16Ty]] 2
71+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec4Ty:]] [[#HFloat16Ty]] 4
72+
; CHECK-SPIRV-DAG: TypeVector [[#HFloat16Vec16Ty:]] [[#HFloat16Ty]] 16
4873
; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HFloat16Const:]] 15360
4974
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec8Ty]] [[#HFloat16Vec8Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5075
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec2Ty]] [[#HFloat16Vec2Const:]] [[#HFloat16Const]] [[#HFloat16Const]]
76+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec4Ty]] [[#HFloat16Vec4Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
77+
; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16Vec16Ty]] [[#HFloat16Vec16Const:]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]] [[#HFloat16Const]]
5178

5279
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
5380
target triple = "spir-unknown-unknown"
@@ -129,3 +156,126 @@ entry:
129156
}
130157

131158
declare dso_local spir_func i8 @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv2_Dh(<2 x half>)
159+
160+
; Packed in 16-bit integer
161+
162+
; CHECK-SPIRV: Function [[#]] [[#fp4e2m1_hf8_16]] [[#]]
163+
; CHECK-SPIRV: Bitcast [[#E2M1Vec4Ty]] [[#Cast1:]] [[#Int16Const]]
164+
; CHECK-SPIRV: FConvert [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
165+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
166+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
167+
168+
; CHECK-LLVM-LABEL: fp4e2m1_hf8_16
169+
; CHECK-LLVM: %[[#Cast:]] = bitcast i16 1 to <4 x i4>
170+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
171+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
172+
173+
; Test input for FP4 (E2M1) values packed in a scalar i16: passes the constant
; i16 1 to the E2M1-to-E4M3 conversion builtin and returns the <4 x i8> result.
; The SPIR-V expectations for this function look for a bitcast of the i16
; constant to a 4-element E2M1 vector ahead of the FConvert.
define spir_func <4 x i8> @fp4e2m1_hf8_16() {
174+
entry:
175+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELs(i16 1)
176+
ret <4 x i8> %0
177+
}
178+
179+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELs(i16)
180+
181+
; Packed in 64-bit integer
182+
183+
; CHECK-SPIRV: Function [[#]] [[#fp4e2m1_hf8_64]] [[#]]
184+
; CHECK-SPIRV: Bitcast [[#E2M1Vec16Ty]] [[#Cast1:]] [[#Int64Const]]
185+
; CHECK-SPIRV: FConvert [[#Float8E4M3Vec16Ty]] [[#Conv:]] [[#Cast1]]
186+
; CHECK-SPIRV: Bitcast [[#Int8Vec16Ty]] [[#Cast2:]] [[#Conv]]
187+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
188+
189+
; CHECK-LLVM-LABEL: fp4e2m1_hf8_64
190+
; CHECK-LLVM: %[[#Cast:]] = bitcast i64 1 to <16 x i4>
191+
; CHECK-LLVM: %[[#Call:]] = call <16 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELDv16_i(<16 x i4> %[[#Cast]])
192+
; CHECK-LLVM: ret <16 x i8> %[[#Call]]
193+
194+
; Test input for FP4 (E2M1) values packed in a scalar i64: passes the constant
; i64 1 to the E2M1-to-E4M3 conversion builtin and returns the <16 x i8> result.
; The SPIR-V expectations for this function look for a bitcast of the i64
; constant to a 16-element E2M1 vector ahead of the FConvert.
define spir_func <16 x i8> @fp4e2m1_hf8_64() {
195+
entry:
196+
%0 = call spir_func <16 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELl(i64 1)
197+
ret <16 x i8> %0
198+
}
199+
200+
declare dso_local spir_func <16 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELl(i64)
201+
202+
; Packed in vector of 8-bit integers
203+
204+
; CHECK-SPIRV: Function [[#]] [[#fp4e2m1_hf8_vec2xi8]] [[#]]
205+
; CHECK-SPIRV: Bitcast [[#E2M1Vec4Ty]] [[#Cast1:]] [[#Int8Vec2Const]]
206+
; CHECK-SPIRV: FConvert [[#Float8E4M3Vec4Ty]] [[#Conv:]] [[#Cast1]]
207+
; CHECK-SPIRV: Bitcast [[#Int8Vec4Ty]] [[#Cast2:]] [[#Conv]]
208+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
209+
210+
; CHECK-LLVM-LABEL: fp4e2m1_hf8_vec2xi8
211+
; CHECK-LLVM: %[[#Cast:]] = bitcast <2 x i8> splat (i8 1) to <4 x i4>
212+
; CHECK-LLVM: %[[#Call:]] = call <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELDv4_i(<4 x i4> %[[#Cast]])
213+
; CHECK-LLVM: ret <4 x i8> %[[#Call]]
214+
215+
; Test input for FP4 (E2M1) values packed in a vector of 8-bit integers: passes
; <2 x i8> <1, 1> to the E2M1-to-E4M3 conversion builtin and returns the
; <4 x i8> result. The SPIR-V expectations for this function look for a bitcast
; of the 2-element i8 vector constant to a 4-element E2M1 vector.
; NOTE(review): the callee's mangled suffix is Dv4_i although the argument type
; is <2 x i8>; the reverse-translation expectations use the same symbol with a
; <4 x i4> argument. Confirm the suffix is intentional.
define spir_func <4 x i8> @fp4e2m1_hf8_vec2xi8() {
216+
entry:
217+
%0 = call spir_func <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELDv4_i(<2 x i8> <i8 1, i8 1>)
218+
ret <4 x i8> %0
219+
}
220+
221+
declare dso_local spir_func <4 x i8> @_Z38__builtin_spirv_ConvertE2M1ToE4M3INTELDv4_i(<2 x i8>)
222+
223+
; To packed in 16-bit integer
224+
225+
; CHECK-SPIRV: Function [[#]] [[#hf16_fp4e2m1_16]] [[#]]
226+
; CHECK-SPIRV: FConvert [[#E2M1Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
227+
; CHECK-SPIRV: Bitcast [[#Int16Ty]] [[#Cast2:]] [[#Conv]]
228+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
229+
230+
; CHECK-LLVM-LABEL: hf16_fp4e2m1_16
231+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv4_Dh(<4 x half> splat (half 0xH3C00))
232+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to i16
233+
; CHECK-LLVM: ret i16 %[[#Cast]]
234+
235+
; Test input for converting to FP4 (E2M1) packed in a scalar i16: passes a
; <4 x half> vector of 1.0 to the FP16-to-E2M1 conversion builtin and returns
; the i16 result. The SPIR-V expectations for this function look for an
; FConvert to a 4-element E2M1 vector followed by a bitcast to the 16-bit
; integer type.
; NOTE(review): the call below has no spir_func calling convention, while the
; matching declaration is spir_func. Confirm the omission is intentional.
define spir_func i16 @hf16_fp4e2m1_16() {
236+
entry:
237+
%0 = call i16 @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
238+
ret i16 %0
239+
}
240+
241+
declare dso_local spir_func i16 @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv4_Dh(<4 x half>)
242+
243+
; To packed in 64-bit integer
244+
245+
; CHECK-SPIRV: Function [[#]] [[#hf16_fp4e2m1_64]] [[#]]
246+
; CHECK-SPIRV: FConvert [[#E2M1Vec16Ty]] [[#Conv:]] [[#HFloat16Vec16Const]]
247+
; CHECK-SPIRV: Bitcast [[#Int64Ty]] [[#Cast2:]] [[#Conv]]
248+
; CHECK-SPIRV: ReturnValue [[#Cast2]]
249+
250+
; CHECK-LLVM-LABEL: hf16_fp4e2m1_64
251+
; CHECK-LLVM: %[[#Call:]] = call <16 x i4> @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv16_Dh(<16 x half> splat (half 0xH3C00))
252+
; CHECK-LLVM: %[[#Cast:]] = bitcast <16 x i4> %[[#Call]] to i64
253+
; CHECK-LLVM: ret i64 %[[#Cast]]
254+
255+
; Test input for converting to FP4 (E2M1) packed in a scalar i64: passes a
; <16 x half> vector of 1.0 to the FP16-to-E2M1 conversion builtin and returns
; the i64 result. The SPIR-V expectations for this function look for an
; FConvert to a 16-element E2M1 vector followed by a bitcast to the 64-bit
; integer type.
; NOTE(review): the call below has no spir_func calling convention, while the
; matching declaration is spir_func. Confirm the omission is intentional.
define spir_func i64 @hf16_fp4e2m1_64() {
256+
entry:
257+
%0 = call i64 @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv16_Dh(<16 x half> <half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0, half 1.0>)
258+
ret i64 %0
259+
}
260+
261+
declare dso_local spir_func i64 @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv16_Dh(<16 x half>)
262+
263+
; To packed in vector of 8-bit integers
264+
265+
; CHECK-SPIRV: Function [[#]] [[#hf16_fp4e2m1_vec2xi8]] [[#]]
266+
; CHECK-SPIRV: FConvert [[#E2M1Vec4Ty]] [[#Conv:]] [[#HFloat16Vec4Const]]
267+
; CHECK-SPIRV: Bitcast [[#Int8Vec2Ty]] [[#Cast:]] [[#Conv]]
268+
; CHECK-SPIRV: ReturnValue [[#Cast]]
269+
270+
; CHECK-LLVM-LABEL: hf16_fp4e2m1_vec2xi8
271+
; CHECK-LLVM: %[[#Call:]] = call <4 x i4> @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELDv4_Dh(<4 x half> splat (half 0xH3C00))
272+
; CHECK-LLVM: %[[#Cast:]] = bitcast <4 x i4> %[[#Call]] to <2 x i8>
273+
; CHECK-LLVM: ret <2 x i8> %[[#Cast]]
274+
275+
; Test input for converting to FP4 (E2M1) packed in a vector of 8-bit integers:
; passes a <4 x half> vector of 1.0 to the FP16-to-E2M1 conversion builtin and
; returns the <2 x i8> result. The SPIR-V expectations for this function look
; for an FConvert to a 4-element E2M1 vector followed by a bitcast to the
; 2-element i8 vector type.
; NOTE(review): the callee name has an extra 'K' before Dv4_Dh compared with
; the i16-returning declaration called by hf16_fp4e2m1_16, presumably so the
; two declarations with different return types can coexist in one module. The
; reverse-translation expectations use the name without the 'K'. Confirm.
; NOTE(review): the call below has no spir_func calling convention, while the
; matching declaration is spir_func. Confirm the omission is intentional.
define spir_func <2 x i8> @hf16_fp4e2m1_vec2xi8() {
276+
entry:
277+
%0 = call <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELKDv4_Dh(<4 x half> <half 1.0, half 1.0, half 1.0, half 1.0>)
278+
ret <2 x i8> %0
279+
}
280+
281+
declare dso_local spir_func <2 x i8> @_Z38__builtin_spirv_ConvertFP16ToE2M1INTELKDv4_Dh(<4 x half>)

0 commit comments

Comments
 (0)