Skip to content

Commit 37d613a

Browse files
committed
[Backport to 15] Fix BFloat16 argument type demangling (#3563)
This also includes patrial cherry-pick of fd1ed03 "Workaround for `bfloat16` parameter type demangling"
1 parent f6d1f4b commit 37d613a

3 files changed

Lines changed: 84 additions & 2 deletions

File tree

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,15 +813,34 @@ void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
813813
if (HasSret)
814814
++ArgIter;
815815

816+
// "DF<N>b" mangling for bfloat<N> types (e.g. DF16b for bfloat16) is
817+
// recognized by the demangler only starting from LLVM 20. Replace "DF16b"
818+
// in the parameter section with the vendor-extended-type encoding "u6__bf16",
819+
// which all known demangler versions parse correctly as NameType("__bf16").
820+
std::string PatchedName;
821+
StringRef MangledName(F->getName());
822+
if (MangledName.contains("DF16b")) {
823+
PatchedName = MangledName.str();
824+
// Skip "_Z<N><name>" to search only in the parameter section.
825+
const size_t Start = PatchedName.find_first_not_of("0123456789", 2);
826+
size_t Len = 0;
827+
StringRef(PatchedName).substr(2, Start - 2).getAsInteger(10, Len);
828+
size_t Pos = Start + Len;
829+
while ((Pos = PatchedName.find("DF16b", Pos)) != std::string::npos) {
830+
PatchedName.replace(Pos, 5, "u6__bf16");
831+
Pos += 8;
832+
}
833+
MangledName = PatchedName;
834+
}
835+
816836
// Demangle the function arguments. If we get an input name of
817837
// "_Z12write_imagei20ocl_image1d_array_woDv2_iiDv4_i", then we expect
818838
// that Demangler.getFunctionParameters will return
819839
// "(ocl_image1d_array_wo, int __vector(2), int, int __vector(4))" (in other
820840
// words, the stuff between the parentheses if you ran C++ filt, including
821841
// the parentheses itself).
822842
ItaniumPartialDemangler Demangler;
823-
std::string MangledName = F->getName().str();
824-
if (Demangler.partialDemangle(MangledName.c_str()))
843+
if (Demangler.partialDemangle(MangledName.data()))
825844
return;
826845
char *Buf = nullptr;
827846
size_t BufLen = 0;
@@ -865,6 +884,8 @@ void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
865884
Pointee = getOrCreateOpaqueStructType(M, StructName);
866885
} else if (MangledStructName.startswith("opencl.")) {
867886
Pointee = getOrCreateOpaqueStructType(M, MangledStructName);
887+
} else if (MangledStructName == "__bf16") {
888+
Pointee = Type::getBFloatTy(M->getContext());
868889
}
869890
} else if (!Arg.contains(' ') && Arg.startswith("ocl_")) {
870891
std::string StructName = demangleBuiltinOpenCLTypeName(Arg);

test/extensions/INTEL/SPV_INTEL_bfloat16/bfloat16_math.ll renamed to test/extensions/INTEL/SPV_INTEL_bfloat16_arithmetic/bfloat16_math.ll

File renamed without changes.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -spirv-text -o %t.spt --spirv-ext=+SPV_KHR_bfloat16,+SPV_INTEL_bfloat16_arithmetic
3+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
4+
; RUN: llvm-spirv -to-binary %t.spt -o %t.spv
5+
; TODO: reenable the validation once the BFloat16 type is supported in ExtInst.
6+
; Currently fails with: ExtInst doesn't support BFloat16 type.
7+
; RUNx: spirv-val %t.spv
8+
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o %t.rev.ll
9+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
10+
; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o - | llvm-dis -o %t.rev.ll
11+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-SPV-IR
12+
13+
14+
; CHECK-SPIRV: Capability BFloat16TypeKHR
15+
; CHECK-SPIRV: Extension "SPV_KHR_bfloat16"
16+
; CHECK-SPIRV: TypeFloat [[#BFLOAT:]] 16 0
17+
; CHECK-SPIRV: TypeVector [[#VEC:]] [[#BFLOAT]] 2
18+
; CHECK-SPIRV: TypePointer [[#PTR:]] [[#]] [[#BFLOAT]]
19+
20+
; CHECK-LABEL: Function
21+
; CHECK-SPIRV: FunctionParameter [[#PTR]] [[#PTR_ARG:]]
22+
; CHECK-SPIRV: ExtInst [[#VEC]] [[#]] [[#]] vloadn [[#]] [[#PTR_ARG]] 2
23+
24+
; CHECK-LABEL: Function
25+
; CHECK-SPIRV: FunctionParameter [[#VEC]] [[#DATA_ARG:]]
26+
; CHECK-SPIRV: FunctionParameter [[#PTR]] [[#PTR_ARG2:]]
27+
; CHECK-SPIRV: ExtInst [[#]] [[#]] [[#]] vstoren [[#DATA_ARG]] [[#]] [[#PTR_ARG2]]
28+
29+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
30+
target triple = "spir64-unknown-unknown"
31+
32+
; CHECK-LLVM: call spir_func <2 x bfloat> @_Z6vload2mPU3AS1KDF16b(i64 %offset, bfloat addrspace(1)* %ptr)
33+
; CHECK-LLVM: call spir_func void @_Z7vstore2Dv2_DF16bmPU3AS1DF16b(<2 x bfloat> %data, i64 %offset, bfloat addrspace(1)* %ptr)
34+
35+
; CHECK-SPV-IR: call spir_func <2 x bfloat> @_Z26__spirv_ocl_vloadn_RDF16b2mPU3AS1KDF16bi(i64 %offset, bfloat addrspace(1)* %ptr, i32 2)
36+
; CHECK-SPV-IR: call spir_func void @_Z19__spirv_ocl_vstorenDv2_DF16bmPU3AS1DF16b(<2 x bfloat> %data, i64 %offset, bfloat addrspace(1)* %ptr)
37+
38+
define spir_func <2 x bfloat> @test_spirv_ocl_vload2(i64 %offset, ptr addrspace(1) %ptr) {
39+
%result = call spir_func <2 x bfloat> @_Z26__spirv_ocl_vloadn__RDF16blPU3AS1DF16bi(i64 %offset, bfloat addrspace(1)* %ptr, i32 2)
40+
ret <2 x bfloat> %result
41+
}
42+
43+
define spir_func void @test_spirv_ocl_vstore2(<2 x bfloat> %data, i64 %offset, ptr addrspace(1) %ptr) {
44+
call spir_func void @_Z19__spirv_ocl_vstorenDv2_DF16blPU3AS1DF16b(<2 x bfloat> %data, i64 %offset, bfloat addrspace(1)* %ptr)
45+
ret void
46+
}
47+
48+
declare spir_func <2 x bfloat> @_Z26__spirv_ocl_vloadn__RDF16blPU3AS1DF16bi(i64, bfloat addrspace(1)*, i32)
49+
declare spir_func void @_Z19__spirv_ocl_vstorenDv2_DF16blPU3AS1DF16b(<2 x bfloat>, i64, bfloat addrspace(1)*)
50+
51+
!opencl.enable.FP_CONTRACT = !{}
52+
!opencl.spir.version = !{!0}
53+
!opencl.ocl.version = !{!1}
54+
!opencl.used.extensions = !{!2}
55+
!opencl.used.optional.core.features = !{!3}
56+
!opencl.compiler.options = !{!3}
57+
58+
!0 = !{i32 1, i32 2}
59+
!1 = !{i32 2, i32 0}
60+
!2 = !{!"cl_khr_fp16"}
61+
!3 = !{}

0 commit comments

Comments
 (0)