Skip to content

Commit adb8668

Browse files
authored
Merge pull request #117 from GPUOpen-LibrariesAndSDKs/feature/ORO-0-radix-sort-one-sweep
Radix sort one-sweep
2 parents 2668a99 + e8e2d14 commit adb8668

6 files changed

Lines changed: 476 additions & 853 deletions

File tree

ParallelPrimitives/RadixSort.cpp

Lines changed: 93 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -42,47 +42,50 @@ namespace hip
4242
{
4343
static const char** RadixSortKernelsArgs = nullptr;
4444
static const char** RadixSortKernelsIncludes = nullptr;
45-
}
45+
} // namespace hip
4646
#endif
4747

4848
#if defined( __GNUC__ )
4949
#include <dlfcn.h>
5050
#endif
5151

52-
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
52+
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
5353
#include <ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
5454
#else
5555
const unsigned char oro_compiled_kernels_h[] = "";
5656
const size_t oro_compiled_kernels_h_size = 0;
5757
#endif
5858

59+
constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
60+
constexpr uint64_t next_multiple64( uint64_t val, uint64_t divisor ) noexcept { return div_round_up64( val, divisor ) * divisor; }
61+
5962
namespace
6063
{
6164

6265
// if those 2 preprocessors are enabled, this activates the 'usePrecompiledAndBakedKernel' mode.
63-
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
64-
65-
// this flag means that we bake the precompiled kernels
66-
constexpr auto usePrecompiledAndBakedKernel = true;
66+
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
67+
68+
// this flag means that we bake the precompiled kernels
69+
constexpr auto usePrecompiledAndBakedKernel = true;
6770

68-
constexpr auto useBitCode = false;
69-
constexpr auto useBakeKernel = false;
71+
constexpr auto useBitCode = false;
72+
constexpr auto useBakeKernel = false;
7073

7174
#else
7275

73-
constexpr auto usePrecompiledAndBakedKernel = false;
76+
constexpr auto usePrecompiledAndBakedKernel = false;
7477

75-
#if defined( ORO_PRECOMPILED )
76-
constexpr auto useBitCode = true; // this flag means we use the bitcode file
77-
#else
78-
constexpr auto useBitCode = false;
79-
#endif
78+
#if defined( ORO_PRECOMPILED )
79+
constexpr auto useBitCode = true; // this flag means we use the bitcode file
80+
#else
81+
constexpr auto useBitCode = false;
82+
#endif
8083

81-
#if defined( ORO_PP_LOAD_FROM_STRING )
82-
constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embeded in the binary ( as a string )
83-
#else
84-
constexpr auto useBakeKernel = false;
85-
#endif
84+
#if defined( ORO_PP_LOAD_FROM_STRING )
85+
constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embeded in the binary ( as a string )
86+
#else
87+
constexpr auto useBakeKernel = false;
88+
#endif
8689

8790
#endif
8891

@@ -124,23 +127,6 @@ RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils, oroStream stream,
124127
configure( kernelPath, includeDir, stream );
125128
}
126129

127-
void RadixSort::exclusiveScanCpu( const Oro::GpuMemory<int>& countsGpu, Oro::GpuMemory<int>& offsetsGpu ) const noexcept
128-
{
129-
const auto buffer_size = countsGpu.size();
130-
131-
std::vector<int> counts = countsGpu.getData();
132-
std::vector<int> offsets( buffer_size );
133-
134-
int sum = 0;
135-
for( int i = 0; i < counts.size(); ++i )
136-
{
137-
offsets[i] = sum;
138-
sum += counts[i];
139-
}
140-
141-
offsetsGpu.copyFromHost( offsets.data(), std::size( offsets ) );
142-
}
143-
144130
void RadixSort::compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept
145131
{
146132
static constexpr auto defaultKernelPath{ "../ParallelPrimitives/RadixSortKernels.h" };
@@ -172,77 +158,38 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
172158
binaryPath = getCurrentDir();
173159
binaryPath += isAmd ? "oro_compiled_kernels.hipfb" : "oro_compiled_kernels.fatbin";
174160
log = "loading pre-compiled kernels at path : " + binaryPath;
175-
176-
m_num_threads_per_block_for_count = DEFAULT_COUNT_BLOCK_SIZE;
177-
m_num_threads_per_block_for_scan = DEFAULT_SCAN_BLOCK_SIZE;
178-
m_num_threads_per_block_for_sort = DEFAULT_SORT_BLOCK_SIZE;
179-
180-
m_warp_size = DEFAULT_WARP_SIZE;
181161
}
182162
else
183163
{
184164
log = "compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir;
185-
186-
m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE;
187-
m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE;
188-
m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE;
189-
190-
m_warp_size = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;
191-
192-
assert( m_num_threads_per_block_for_count % m_warp_size == 0 );
193-
assert( m_num_threads_per_block_for_scan % m_warp_size == 0 );
194-
assert( m_num_threads_per_block_for_sort % m_warp_size == 0 );
195165
}
196166

197-
m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / m_warp_size;
198-
199167
if( m_flags == Flag::LOG )
200168
{
201169
std::cout << log << std::endl;
202170
}
203171

172+
const auto includeArg{ "-I" + currentIncludeDir };
173+
std::vector<const char*> opts;
174+
opts.push_back( includeArg.c_str() );
175+
204176
struct Record
205177
{
206178
std::string kernelName;
207179
Kernel kernelType;
208180
};
209181

210-
const std::vector<Record> records{
211-
{ "CountKernel", Kernel::COUNT }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL }, { "SortKernel", Kernel::SORT },
212-
{ "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
213-
};
214-
215-
const auto includeArg{ "-I" + currentIncludeDir };
216-
const auto overwrite_flag = "-DOVERWRITE";
217-
const auto count_block_size_param = "-DCOUNT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_count );
218-
const auto scan_block_size_param = "-DSCAN_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_scan );
219-
const auto sort_block_size_param = "-DSORT_WG_SIZE_VAL=" + std::to_string( m_num_threads_per_block_for_sort );
220-
const auto sort_num_warps_param = "-DSORT_NUM_WARPS_PER_BLOCK_VAL=" + std::to_string( m_num_warps_per_block_for_sort );
221-
222-
std::vector<const char*> opts;
223-
224-
if( const std::string device_name = m_props.name; device_name.find( "NVIDIA" ) != std::string::npos )
225-
{
226-
opts.push_back( "--use_fast_math" );
227-
}
228-
else
229-
{
230-
opts.push_back( "-ffast-math" );
231-
}
232-
233-
opts.push_back( includeArg.c_str() );
234-
opts.push_back( overwrite_flag );
235-
opts.push_back( count_block_size_param.c_str() );
236-
opts.push_back( scan_block_size_param.c_str() );
237-
opts.push_back( sort_block_size_param.c_str() );
238-
opts.push_back( sort_num_warps_param.c_str() );
239-
182+
const std::vector<Record> records{ { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS },
183+
{ "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
184+
{ "GHistogram", Kernel::SORT_GHISTOGRAM },
185+
{ "OnesweepReorderKey64", Kernel::SORT_ONESWEEP_REORDER_KEY_64 },
186+
{ "OnesweepReorderKeyPair64", Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } };
240187

241188
for( const auto& record : records )
242189
{
243190
if constexpr( usePrecompiledAndBakedKernel )
244191
{
245-
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() );
192+
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData( oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() );
246193
}
247194
else if constexpr( useBakeKernel )
248195
{
@@ -262,120 +209,101 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
262209
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
263210
}
264211
}
265-
266-
return;
267-
}
268-
269-
int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept
270-
{
271-
const int warpPerWG = blockSize / m_warp_size;
272-
const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / m_warp_size;
273-
const int occupancyFromWarp = ( warpPerWGP > 0 ) ? ( warpPerWGP / warpPerWG ) : 1;
274-
275-
const int occupancy = std::max( 1, occupancyFromWarp );
276-
277-
if( m_flags == Flag::LOG )
278-
{
279-
std::cout << "Occupancy: " << occupancy << '\n';
280-
}
281-
282-
static constexpr auto min_num_blocks = 16;
283-
auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks;
284-
285-
if( m_num_threads_per_block_for_scan > BIN_SIZE )
286-
{
287-
// Note: both are divisible by 2
288-
const auto base = m_num_threads_per_block_for_scan / BIN_SIZE;
289-
290-
// Floor
291-
number_of_blocks = ( number_of_blocks / base ) * base;
292-
}
293-
294-
return number_of_blocks;
295212
}
296213

297214
void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
298215
{
299216
compileKernels( kernelPath, includeDir );
300217

301-
m_num_blocks_for_count = calculateWGsToExecute( m_num_threads_per_block_for_count );
218+
constexpr bool enable_copying = false;
219+
constexpr auto key_type_size = sizeof( std::remove_pointer_t<decltype( KeyValueSoA::key )> );
302220

303-
/// The tmp buffer size of the count kernel and the scan kernel.
221+
constexpr u64 gpSumBuffer = sizeof( u32 ) * BIN_SIZE * key_type_size;
222+
m_gpSumBuffer.resizeAsync( gpSumBuffer, enable_copying /*copy*/, stream );
304223

305-
const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count;
224+
u64 lookBackBuffer = sizeof( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE );
225+
m_lookbackBuffer.resizeAsync( lookBackBuffer, enable_copying /*copy*/, stream );
306226

307-
/// @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan
308-
/// This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly
309-
310-
m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan;
311-
312-
m_tmp_buffer.resizeAsync( tmp_buffer_size, false, stream );
313-
314-
if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
315-
{
316-
// These are for the scan kernel
317-
m_partial_sum.resizeAsync( m_num_blocks_for_scan, false, stream );
318-
m_is_ready.resizeAsync( m_num_blocks_for_scan, false, stream );
319-
m_is_ready.resetAsync( stream );
320-
}
227+
m_tailIterator.resizeAsync( 1, enable_copying /*copy*/, stream );
228+
m_tailIterator.resetAsync( stream );
229+
m_gpSumCounter.resizeAsync( 1, enable_copying /*copy*/, stream );
321230
}
322231
void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; }
323232

324-
void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream ) noexcept
233+
void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept
325234
{
235+
bool keyPair = src.value != nullptr;
236+
326237
// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
327238
// right now, setting this as large as possible is faster than multi pass sorting
328239
if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
329240
{
330-
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
331-
const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit };
332-
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
241+
if( keyPair )
242+
{
243+
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS_KV];
244+
const void* args[] = { &src.key, &src.value, &dst.key, &dst.value, &n, &startBit, &endBit };
245+
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
246+
}
247+
else
248+
{
249+
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
250+
const void* args[] = { &src, &dst, &n, &startBit, &endBit };
251+
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
252+
}
333253
return;
334254
}
335255

336-
auto* s{ &src };
337-
auto* d{ &dst };
256+
constexpr uint64_t bit_per_iteration = 8ULL;
338257

339-
for( int i = startBit; i < endBit; i += N_RADIX )
340-
{
341-
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );
258+
int nIteration = div_round_up64( endBit - startBit, bit_per_iteration );
259+
uint64_t numberOfBlocks = div_round_up64( n, RADIX_SORT_BLOCK_SIZE );
342260

343-
std::swap( s, d );
344-
}
261+
m_lookbackBuffer.resetAsync( stream );
262+
m_gpSumCounter.resetAsync( stream );
263+
m_gpSumBuffer.resetAsync( stream );
345264

346-
if( s == &src )
265+
// counter for gHistogram.
347266
{
348-
OrochiUtils::copyDtoDAsync( dst.key, src.key, n, stream );
349-
OrochiUtils::copyDtoDAsync( dst.value, src.value, n, stream );
350-
}
351-
}
267+
int maxBlocksPerMP = 0;
268+
oroError e = oroModuleOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM], GHISTOGRAM_THREADS_PER_BLOCK, 0 );
269+
const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048;
352270

353-
void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int endBit, oroStream stream ) noexcept
354-
{
355-
// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
356-
// right now, setting this as large as possible is faster than multi pass sorting
357-
if( n < SINGLE_SORT_WG_SIZE * SINGLE_SORT_N_ITEMS_PER_WI )
358-
{
359-
const auto func = oroFunctions[Kernel::SORT_SINGLE_PASS];
360-
const void* args[] = { &src, &dst, &n, &startBit, &endBit };
361-
OrochiUtils::launch1D( func, SINGLE_SORT_WG_SIZE, args, SINGLE_SORT_WG_SIZE, 0, stream );
362-
return;
271+
const void* args[] = { &src.key, &n, arg_cast( m_gpSumBuffer.address() ), &startBit, arg_cast( m_gpSumCounter.address() ) };
272+
OrochiUtils::launch1D( oroFunctions[Kernel::SORT_GHISTOGRAM], nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0, stream );
363273
}
364274

365-
auto* s{ &src };
366-
auto* d{ &dst };
367-
368-
for( int i = startBit; i < endBit; i += N_RADIX )
275+
auto s = src;
276+
auto d = dst;
277+
for( int i = 0; i < nIteration; ++i )
369278
{
370-
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );
279+
if( numberOfBlocks < LOOKBACK_TABLE_SIZE * 2 )
280+
{
281+
m_lookbackBuffer.resetAsync( stream );
282+
} // other wise, we can skip zero clear look back buffer
371283

284+
if( keyPair )
285+
{
286+
const void* args[] = { &s.key, &d.key, &s.value, &d.value, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
287+
OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
288+
}
289+
else
290+
{
291+
const void* args[] = { &s.key, &d.key, &n, arg_cast( m_gpSumBuffer.address() ), arg_cast( m_lookbackBuffer.address() ), arg_cast( m_tailIterator.address() ), &startBit, &i };
292+
OrochiUtils::launch1D( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_64], numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0, stream );
293+
}
372294
std::swap( s, d );
373295
}
374296

375-
if( s == &src )
297+
if( s.key == src.key )
376298
{
377-
OrochiUtils::copyDtoDAsync( dst, src, n, stream );
299+
m_oroutils.copyDtoDAsync( dst.key, src.key, n, stream );
300+
301+
if( keyPair )
302+
{
303+
m_oroutils.copyDtoDAsync( dst.value, src.value, n, stream );
304+
}
378305
}
379306
}
380307

308+
void RadixSort::sort( u32* src, u32* dst, uint32_t n, int startBit, int endBit, oroStream stream ) noexcept { sort( KeyValueSoA{ src, nullptr }, KeyValueSoA{ dst, nullptr }, n, startBit, endBit, stream ); }
381309
}; // namespace Oro

0 commit comments

Comments
 (0)