@@ -157,9 +157,11 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
157157 Kernel kernelType;
158158 };
159159
160- const std::vector<Record> records{
161- { " SortSinglePassKernel" , Kernel::SORT_SINGLE_PASS }, { " SortSinglePassKVKernel" , Kernel::SORT_SINGLE_PASS_KV },
162- };
160+ const std::vector<Record> records{ { " SortSinglePassKernel" , Kernel::SORT_SINGLE_PASS },
161+ { " SortSinglePassKVKernel" , Kernel::SORT_SINGLE_PASS_KV },
162+ { " GHistogram" , Kernel::SORT_GHISTOGRAM },
163+ { " OnesweepReorderKey64" , Kernel::SORT_ONESWEEP_REORDER_KEY_64 },
164+ { " OnesweepReorderKeyPair64" , Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64 } };
163165
164166
165167 for ( const auto & record : records )
@@ -187,13 +189,6 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
187189 }
188190 }
189191
190- // TODO: bit code support?
191- #define LOAD_FUNC ( var, kernel ) var = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), kernel, &opts );
192- LOAD_FUNC ( m_gHistogram, " gHistogram" );
193- LOAD_FUNC ( m_onesweep_reorderKey64, " onesweep_reorderKey64" );
194- LOAD_FUNC ( m_onesweep_reorderKeyPair64, " onesweep_reorderKeyPair64" );
195- #undef LOAD_FUNC
196-
197192}
198193
199194void RadixSort::configure ( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
@@ -245,11 +240,11 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n
245240 // counter for gHistogram.
246241 {
247242 int maxBlocksPerMP = 0 ;
248- oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor ( &maxBlocksPerMP, m_gHistogram , GHISTOGRAM_THREADS_PER_BLOCK, 0 );
243+ oroError e = oroOccupancyMaxActiveBlocksPerMultiprocessor ( &maxBlocksPerMP, oroFunctions[Kernel::SORT_GHISTOGRAM] , GHISTOGRAM_THREADS_PER_BLOCK, 0 );
249244 const int nBlocks = e == oroSuccess ? maxBlocksPerMP * m_props.multiProcessorCount : 2048 ;
250245
251246 const void * args[] = { &src.key , &n, arg_cast ( m_gpSumBuffer.address () ), &startBit, arg_cast ( m_gpSumCounter.address () ) };
252- OrochiUtils::launch1D ( m_gHistogram , nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0 , stream );
247+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_GHISTOGRAM] , nBlocks * GHISTOGRAM_THREADS_PER_BLOCK, args, GHISTOGRAM_THREADS_PER_BLOCK, 0 , stream );
253248 }
254249
255250 auto s = src;
@@ -264,12 +259,12 @@ void RadixSort::sort( const KeyValueSoA& src, const KeyValueSoA& dst, uint32_t n
264259 if ( keyPair )
265260 {
266261 const void * args[] = { &s.key , &d.key , &s.value , &d.value , &n, arg_cast ( m_gpSumBuffer.address () ), arg_cast ( m_lookbackBuffer.address () ), arg_cast ( m_tailIterator.address () ), &startBit, &i };
267- OrochiUtils::launch1D ( m_onesweep_reorderKeyPair64 , numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
262+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_PAIR_64] , numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
268263 }
269264 else
270265 {
271266 const void * args[] = { &s.key , &d.key , &n, arg_cast ( m_gpSumBuffer.address () ), arg_cast ( m_lookbackBuffer.address () ), arg_cast ( m_tailIterator.address () ), &startBit, &i };
272- OrochiUtils::launch1D ( m_onesweep_reorderKey64 , numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
267+ OrochiUtils::launch1D ( oroFunctions[Kernel::SORT_ONESWEEP_REORDER_KEY_64] , numberOfBlocks * REORDER_NUMBER_OF_THREADS_PER_BLOCK, args, REORDER_NUMBER_OF_THREADS_PER_BLOCK, 0 , stream );
273268 }
274269 std::swap ( s, d );
275270 }
0 commit comments