@@ -217,15 +217,18 @@ void RadixSort::configure( const std::string& kernelPath, const std::string& inc
217217{
218218 compileKernels ( kernelPath, includeDir );
219219
220- u64 gpSumBuffer = sizeof ( u32 ) * BIN_SIZE * sizeof ( u32 /* key type */ );
221- m_gpSumBuffer.resizeAsync ( gpSumBuffer, false /* copy*/ , stream );
220+ constexpr bool enable_copying = false ;
221+ constexpr auto key_type_size = sizeof (std::remove_pointer_t <decltype (KeyValueSoA::key)>);
222+
223+ constexpr u64 gpSumBuffer = sizeof ( u32 ) * BIN_SIZE * key_type_size;
224+ m_gpSumBuffer.resizeAsync ( gpSumBuffer, enable_copying /* copy*/ , stream );
222225
223226 u64 lookBackBuffer = sizeof ( u64 ) * ( BIN_SIZE * LOOKBACK_TABLE_SIZE );
224- m_lookbackBuffer.resizeAsync ( lookBackBuffer, false /* copy*/ , stream );
227+ m_lookbackBuffer.resizeAsync ( lookBackBuffer, enable_copying /* copy*/ , stream );
225228
226- m_tailIterator.resizeAsync ( 1 , false /* copy*/ , stream );
229+ m_tailIterator.resizeAsync ( 1 , enable_copying /* copy*/ , stream );
227230 m_tailIterator.resetAsync ( stream );
228- m_gpSumCounter.resizeAsync ( 1 , false /* copy*/ , stream );
231+ m_gpSumCounter.resizeAsync ( 1 , enable_copying /* copy*/ , stream );
229232}
230233void RadixSort::setFlag ( Flag flag ) noexcept { m_flags = flag; }
231234
0 commit comments