@@ -42,47 +42,50 @@ namespace hip
4242{
4343static const char ** RadixSortKernelsArgs = nullptr ;
4444static const char ** RadixSortKernelsIncludes = nullptr ;
45- }
45+ } // namespace hip
4646#endif
4747
4848#if defined( __GNUC__ )
4949#include < dlfcn.h>
5050#endif
5151
52- #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
52+ #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
5353#include < ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
5454#else
5555const unsigned char oro_compiled_kernels_h[] = " " ;
5656const size_t oro_compiled_kernels_h_size = 0 ;
5757#endif
5858
59+ constexpr uint64_t div_round_up64 ( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
60+ constexpr uint64_t next_multiple64 ( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64 ( val, divisor ) * divisor; }
61+
5962namespace
6063{
6164
6265// if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode.
63- #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
64-
65- // this flag means that we bake the precompiled kernels
66- constexpr auto usePrecompiledAndBakedKernel = true ;
66+ #if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
67+
68+ // this flag means that we bake the precompiled kernels
69+ constexpr auto usePrecompiledAndBakedKernel = true ;
6770
68- constexpr auto useBitCode = false ;
69- constexpr auto useBakeKernel = false ;
71+ constexpr auto useBitCode = false ;
72+ constexpr auto useBakeKernel = false ;
7073
7174#else
7275
73- constexpr auto usePrecompiledAndBakedKernel = false ;
76+ constexpr auto usePrecompiledAndBakedKernel = false ;
7477
75- #if defined( ORO_PRECOMPILED )
76- constexpr auto useBitCode = true ; // this flag means we use the bitcode file
77- #else
78- constexpr auto useBitCode = false ;
79- #endif
78+ #if defined( ORO_PRECOMPILED )
79+ constexpr auto useBitCode = true ; // this flag means we use the bitcode file
80+ #else
81+ constexpr auto useBitCode = false ;
82+ #endif
8083
81- #if defined( ORO_PP_LOAD_FROM_STRING )
82- constexpr auto useBakeKernel = true ; // this flag means we use the HIP source code embeded in the binary ( as a string )
83- #else
84- constexpr auto useBakeKernel = false ;
85- #endif
84+ #if defined( ORO_PP_LOAD_FROM_STRING )
85+ constexpr auto useBakeKernel = true ; // this flag means we use the HIP source code embedded in the binary ( as a string )
86+ #else
87+ constexpr auto useBakeKernel = false ;
88+ #endif
8689
8790#endif
8891
@@ -124,23 +127,6 @@ RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils, oroStream stream,
124127 configure ( kernelPath, includeDir, stream );
125128}
126129
127- void RadixSort::exclusiveScanCpu ( const Oro::GpuMemory<int >& countsGpu, Oro::GpuMemory<int >& offsetsGpu ) const noexcept
128- {
129- const auto buffer_size = countsGpu.size ();
130-
131- std::vector<int > counts = countsGpu.getData ();
132- std::vector<int > offsets ( buffer_size );
133-
134- int sum = 0 ;
135- for ( int i = 0 ; i < counts.size (); ++i )
136- {
137- offsets[i] = sum;
138- sum += counts[i];
139- }
140-
141- offsetsGpu.copyFromHost ( offsets.data (), std::size ( offsets ) );
142- }
143-
144130void RadixSort::compileKernels ( const std::string& kernelPath, const std::string& includeDir ) noexcept
145131{
146132 static constexpr auto defaultKernelPath{ " ../ParallelPrimitives/RadixSortKernels.h" };
@@ -172,77 +158,38 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
172158 binaryPath = getCurrentDir ();
173159 binaryPath += isAmd ? " oro_compiled_kernels.hipfb" : " oro_compiled_kernels.fatbin" ;
174160 log = " loading pre-compiled kernels at path : " + binaryPath;
175-
176- m_num_threads_per_block_for_count = DEFAULT_COUNT_BLOCK_SIZE;
177- m_num_threads_per_block_for_scan = DEFAULT_SCAN_BLOCK_SIZE;
178- m_num_threads_per_block_for_sort = DEFAULT_SORT_BLOCK_SIZE;
179-
180- m_warp_size = DEFAULT_WARP_SIZE;
181161 }
182162 else
183163 {
184164 log = " compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir;
185-
186- m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE;
187- m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE;
188- m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE;
189-
190- m_warp_size = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;
191-
192- assert ( m_num_threads_per_block_for_count % m_warp_size == 0 );
193- assert ( m_num_threads_per_block_for_scan % m_warp_size == 0 );
194- assert ( m_num_threads_per_block_for_sort % m_warp_size == 0 );
195165 }
196166
197- m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / m_warp_size;
198-
199167 if ( m_flags == Flag::LOG )
200168 {
201169 std::cout << log << std::endl;
202170 }
203171
172+ const auto includeArg{ " -I" + currentIncludeDir };
173+ std::vector<const char *> opts;
174+ opts.push_back ( includeArg.c_str () );
175+
204176 struct Record
205177 {
206178 std::string kernelName;
207179 Kernel kernelType;
208180 };
209181
210- const std::vector<Record> records{
211- { " CountKernel" , Kernel::COUNT }, { " ParallelExclusiveScanSingleWG" , Kernel::SCAN_SINGLE_WG }, { " ParallelExclusiveScanAllWG" , Kernel::SCAN_PARALLEL }, { " SortKernel" , Kernel::SORT },
212- { " SortKVKernel" , Kernel::SORT_KV }, { " SortSinglePassKernel" , Kernel::SORT_SINGLE_PASS }, { " SortSinglePassKVKernel" , Kernel::SORT_SINGLE_PASS_KV },
213- };
214-
215- const auto includeArg{ " -I" + currentIncludeDir };
216- const auto overwrite_flag = " -DOVERWRITE" ;
217- const auto count_block_size_param = " -DCOUNT_WG_SIZE_VAL=" + std::to_string ( m_num_threads_per_block_for_count );
218- const auto scan_block_size_param = " -DSCAN_WG_SIZE_VAL=" + std::to_string ( m_num_threads_per_block_for_scan );
219- const auto sort_block_size_param = " -DSORT_WG_SIZE_VAL=" + std::to_string ( m_num_threads_per_block_for_sort );
220- const auto sort_num_warps_param = " -DSORT_NUM_WARPS_PER_BLOCK_VAL=" + std::to_string ( m_num_warps_per_block_for_sort );
221-
222- std::vector<const char *> opts;
223-
224- if ( const std::string device_name = m_props.name ; device_name.find ( " NVIDIA" ) != std::string::npos )
225- {
226- opts.push_back ( " --use_fast_math" );
227- }
228- else
229- {
230- opts.push_back ( " -ffast-math" );
231- }
232-
233- opts.push_back ( includeArg.c_str () );
234- opts.push_back ( overwrite_flag );
235- opts.push_back ( count_block_size_param.c_str () );
236- opts.push_back ( scan_block_size_param.c_str () );
237- opts.push_back ( sort_block_size_param.c_str () );
238- opts.push_back ( sort_num_warps_param.c_str () );
239-
182+ const std::vector<Record> records{ { " SortSinglePassKernel" , Kernel::SORT_SINGLE_PASS },
183+ { " SortSinglePassKVKernel" , Kernel::SORT_SINGLE_PASS_KV },
184+ { " GHistogram" , Kernel::SORT_GHISTOGRAM },
185+ { " OnesweepReorderKey64" , Kernel::SORT_ONESWEEP_REORDER_KEY_64 },
186+ { " OnesweepReorderKeyPair64" , Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } };
240187
241188 for ( const auto & record : records )
242189 {
243190 if constexpr ( usePrecompiledAndBakedKernel )
244191 {
245- oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromPrecompiledBinary_asData (oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName .c_str () );
192+ oroFunctions[record.kernelType ] = m_oroutils.getFunctionFromPrecompiledBinary_asData ( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName .c_str () );
246193 }
247194 else if constexpr ( useBakeKernel )
248195 {
@@ -262,120 +209,101 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
262209 printKernelInfo ( record.kernelName , oroFunctions[record.kernelType ] );
263210 }
264211 }
265-
266- return ;
267- }
268-
269- int RadixSort::calculateWGsToExecute ( const int blockSize ) const noexcept
270- {
271- const int warpPerWG = blockSize / m_warp_size;
272- const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / m_warp_size;
273- const int occupancyFromWarp = ( warpPerWGP > 0 ) ? ( warpPerWGP / warpPerWG ) : 1 ;
274-
275- const int occupancy = std::max ( 1 , occupancyFromWarp );
276-
277- if ( m_flags == Flag::LOG )
278- {
279- std::cout << " Occupancy: " << occupancy << ' \n ' ;
280- }
281-
282- static constexpr auto min_num_blocks = 16 ;
283- auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks;
284-
285- if ( m_num_threads_per_block_for_scan > BIN_SIZE )
286- {
287- // Note: both are divisible by 2
288- const auto base = m_num_threads_per_block_for_scan / BIN_SIZE;
289-
290- // Floor
291- number_of_blocks = ( number_of_blocks / base ) * base;
292- }
293-
294- return number_of_blocks;
295212}
296213
297214void RadixSort::configure ( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
298215{
299216 compileKernels ( kernelPath, includeDir );
300217
301- m_num_blocks_for_count = calculateWGsToExecute ( m_num_threads_per_block_for_count );
218+ constexpr bool enable_copying = false ;
219+ constexpr auto key_type_size = sizeof ( std::remove_pointer_t <decltype ( KeyValueSoA::key )> );
302220
303- // / The tmp buffer size of the count kernel and the scan kernel.
221+ constexpr u64 gpSumBuffer = sizeof ( u32 ) * BIN_SIZE * key_type_size;
222+ m_gpSumBuffer.resizeAsync ( gpSumBuffer, enable_copying /* copy*/ , stream );
304223
305- const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count;
224+ u64 lookBackBuffer = sizeof ( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE );
225+ m_lookbackBuffer.resizeAsync ( lookBackBuffer, enable_copying /* copy*/ , stream );
306226
307- // / @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan
308- // / This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly
309-
310- m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan;
311-
312- m_tmp_buffer.resizeAsync ( tmp_buffer_size, false , stream );
313-
314- if ( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
315- {
316- // These are for the scan kernel
317- m_partial_sum.resizeAsync ( m_num_blocks_for_scan, false , stream );
318- m_is_ready.resizeAsync ( m_num_blocks_for_scan, false , stream );
319- m_is_ready.resetAsync ( stream );
320- }
227+ m_tailIterator.resizeAsync ( 1 , enable_copying /* copy*/ , stream );
228+ m_tailIterator.resetAsync ( stream );
229+ m_gpSumCounter.resizeAsync ( 1 , enable_copying /* copy*/ , stream );
321230}
322231void RadixSort::setFlag ( Flag flag ) noexcept { m_flags = flag; }
323232
324- void RadixSort::sort ( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream ) noexcept
233+ void RadixSort::sort ( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept
325234{
235+ bool keyPair = src.value != nullptr ;
236+
326237 // todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
327238 // right now, setting this as large as possible is faster than multi pass sorting
328239 if ( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
329240 {
330- const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
331- const void * args[] = { &src.key , &src.value , &dst.key , &dst.value , &n, &startBit, &endBit };
332- OrochiUtils::launch1D ( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0 , stream );
241+ if ( keyPair )
242+ {
243+ const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
244+ const void * args[] = { &src.key , &src.value , &dst.key , &dst.value , &n, &startBit, &endBit };
245+ OrochiUtils::launch1D ( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0 , stream );
246+ }
247+ else
248+ {
249+ const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
250+ const void * args[] = { &src, &dst, &n, &startBit, &endBit };
251+ OrochiUtils::launch1D ( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0 , stream );
252+ }
333253 return ;
334254 }
335255
336- auto * s{ &src };
337- auto * d{ &dst };
256+ constexpr uint64_t bit_per_iteration = 8ULL ;
338257
339- for ( int i = startBit; i < endBit; i += N_RADIX )
340- {
341- sort1pass ( *s, *d, n, i, i + std::min ( N_RADIX, endBit - i ), stream );
258+ int nIteration = div_round_up64 ( endBit - startBit, bit_per_iteration );
259+ uint64_t numberOfBlocks = div_round_up64 ( n, RADIX_SORT_BLOCK_SIZE );
342260
343- std::swap ( s, d );
344- }
261+ m_lookbackBuffer.resetAsync ( stream );
262+ m_gpSumCounter.resetAsync ( stream );
263+ m_gpSumBuffer.resetAsync ( stream );
345264
346- if ( s == &src )
265+ // counter for gHistogram.
347266 {
348- OrochiUtils::copyDtoDAsync ( dst.key , src.key , n, stream );
349- OrochiUtils::copyDtoDAsync ( dst.value , src.value , n, stream );
350- }
351- }
267+ int maxBlocksPerMP = 0 ;
268+ oroError e = oroModuleOccupancyMaxActiveBlocksPerMultiprocessor ( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 );
269+ const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048 ;
352270
353- void RadixSort::sort ( const u32 * src, const u32 * dst, int n, int startBit, int endBit, oroStream stream ) noexcept
354- {
355- // todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
356- // right now, setting this as large as possible is faster than multi pass sorting
357- if ( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
358- {
359- const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
360- const void * args[] = { &src, &dst, &n, &startBit, &endBit };
361- OrochiUtils::launch1D ( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0 , stream );
362- return ;
271+ const void * args[] = { &src.key , &n, arg_cast ( m_gpSumBuffer.address () ), &startBit, arg_cast ( m_gpSumCounter.address () ) };
272+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_GHISTOGRAM], nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0 , stream );
363273 }
364274
365- auto * s{ &src };
366- auto * d{ &dst };
367-
368- for ( int i = startBit; i < endBit; i += N_RADIX )
275+ auto s = src;
276+ auto d = dst;
277+ for ( int i = 0 ; i < nIteration; ++i )
369278 {
370- sort1pass ( *s, *d, n, i, i + std::min ( N_RADIX, endBit - i ), stream );
279+ if ( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 )
280+ {
281+ m_lookbackBuffer.resetAsync ( stream );
282+ } // otherwise, we can skip zero-clearing the look-back buffer
371283
284+ if ( keyPair )
285+ {
286+ const void * args[] = { &s.key , &d.key , &s.value , &d.value , &n, arg_cast ( m_gpSumBuffer.address () ), arg_cast ( m_lookbackBuffer.address () ), arg_cast ( m_tailIterator.address () ), &startBit, &i };
287+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
288+ }
289+ else
290+ {
291+ const void * args[] = { &s.key , &d.key , &n, arg_cast ( m_gpSumBuffer.address () ), arg_cast ( m_lookbackBuffer.address () ), arg_cast ( m_tailIterator.address () ), &startBit, &i };
292+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
293+ }
372294 std::swap ( s, d );
373295 }
374296
375- if ( s == & src )
297+ if ( s. key == src. key )
376298 {
377- OrochiUtils::copyDtoDAsync ( dst, src, n, stream );
299+ m_oroutils.copyDtoDAsync ( dst.key , src.key , n, stream );
300+
301+ if ( keyPair )
302+ {
303+ m_oroutils.copyDtoDAsync ( dst.value , src.value , n, stream );
304+ }
378305 }
379306}
380307
308+ void RadixSort::sort ( u32 * src, u32 * dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept { sort ( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); }
381309}; // namespace Oro
0 commit comments