55#include < iostream>
66#include < numeric>
77
8- #if defined( ORO_PP_LOAD_FROM_STRING )
9-
8+ // if ORO_PP_LOAD_FROM_STRING && ORO_PRECOMPILED -> we load the precompiled/baked kernels.
9+ // if ORO_PP_LOAD_FROM_STRING && NOT ORO_PRECOMPILED -> we load the baked source code kernels (from Kernels.h / KernelArgs.h)
10+ #if !defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
1011// Note: the include order must be in this particular form.
1112// clang-format off
1213#include < ParallelPrimitives/cache/Kernels.h>
1314#include < ParallelPrimitives/cache/KernelArgs.h>
1415// clang-format on
16+ #else
17+ // if Kernels.h / KernelArgs.h are not included, declare nullptr strings
18+ static const char * hip_RadixSortKernels = nullptr ;
19+ namespace hip
20+ {
21+ static const char ** RadixSortKernelsArgs = nullptr ;
22+ static const char ** RadixSortKernelsIncludes = nullptr ;
23+ } // namespace hip
1524#endif
1625
1726#if defined( __GNUC__ )
1827#include < dlfcn.h>
1928#endif
2029
30+ #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
31+ #include < ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
32+ #else
33+ const unsigned char oro_compiled_kernels_h[] = " " ;
34+ const size_t oro_compiled_kernels_h_size = 0 ;
35+ #endif
36+
2137constexpr uint64_t div_round_up64 ( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
2238constexpr uint64_t next_multiple64 ( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64 ( val, divisor ) * divisor; }
2339
2440namespace
2541{
42+
43+ // if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode.
44+ #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
45+
46+ // this flag means that we bake the precompiled kernels
47+ constexpr auto usePrecompiledAndBakedKernel = true ;
48+
49+ constexpr auto useBitCode = false ;
50+ constexpr auto useBakeKernel = false ;
51+
52+ #else
53+
54+ constexpr auto usePrecompiledAndBakedKernel = false ;
55+
2656#if defined( ORO_PRECOMPILED )
27- constexpr auto useBitCode = true ;
57+ constexpr auto useBitCode = true ; // this flag means we use the bitcode file
2858#else
2959constexpr auto useBitCode = false ;
3060#endif
3161
3262#if defined( ORO_PP_LOAD_FROM_STRING )
33- constexpr auto useBakeKernel = true ;
63+ constexpr auto useBakeKernel = true ; // this flag means we use the HIP source code embeded in the binary ( as a string )
3464#else
3565constexpr auto useBakeKernel = false ;
36- static const char * hip_RadixSortKernels = nullptr ;
37- namespace hip
38- {
39- static const char ** RadixSortKernelsArgs = nullptr ;
40- static const char ** RadixSortKernelsIncludes = nullptr ;
41- } // namespace hip
66+ #endif
67+
4268#endif
4369
4470static_assert ( !( useBitCode && useBakeKernel ), " useBitCode and useBakeKernel cannot coexist" );
@@ -138,11 +164,15 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
138164
139165 for ( const auto & record : records )
140166 {
141- #if defined( ORO_PP_LOAD_FROM_STRING )
142- oroFunctions[record.kernelType ] = oroutils.getFunctionFromString ( device, hip_RadixSortKernels, currentKernelPath.c_str (), record.kernelName .c_str (), &opts, 1 , hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
143- #else
144-
145- if constexpr ( useBitCode )
167+ if constexpr ( usePrecompiledAndBakedKernel )
168+ {
169+ oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromPrecompiledBinary_asData ( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName .c_str () );
170+ }
171+ else if constexpr ( useBakeKernel )
172+ {
173+ oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromString ( m_device, hip_RadixSortKernels, currentKernelPath.c_str (), record.kernelName .c_str (), &opts, 1 , hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
174+ }
175+ else if constexpr ( useBitCode )
146176 {
147177 oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromPrecompiledBinary ( binaryPath.c_str (), record.kernelName .c_str () );
148178 }
@@ -151,7 +181,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
151181 oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromFile ( m_device, currentKernelPath.c_str (), record.kernelName .c_str (), &opts );
152182 }
153183
154- #endif
155184 if ( m_flags == Flag::LOG )
156185 {
157186 printKernelInfo ( record.kernelName , oroFunctions[record.kernelType ] );
0 commit comments