Skip to content

Commit 7ac604c

Browse files
committed
Refactor and optimize bitonic sort implementation
Refactored `LocalPasses<sortable_t, 1, Comparator>::operator()`. Introduced a new `bitonic_sort_config` structure with clearer naming and additional constants (`ElementsInShared`, `R`, `Batches`) for better readability and shared memory handling. Replaced old `BitonicSort` template specializations with modular helper functions (`atomicOpPairs`, `inThreadShuffleSortPairs`, `subgroupShuffleSortPairs`, `workgroupShuffleSortPairs`) to handle different stages of the sorting algorithm. Added a `ShufflesSort` function to dynamically select shuffle/sort operations based on stride size. Refactored the `BitonicSort` structure to use the new modular functions, simplifying the `__call` method and improving the sorting logic.
1 parent 2d6bf0b commit 7ac604c

3 files changed

Lines changed: 301 additions & 249 deletions

File tree

include/nbl/builtin/hlsl/bitonic_sort/common.hlsl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
55
#include "nbl/builtin/hlsl/functional.hlsl"
6+
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
7+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
68

79
namespace nbl
810
{
@@ -40,12 +42,16 @@ struct LocalPasses<sortable_t, 1, Comparator>
4042

4143
void operator()(bool ascending, sortable_t data[N], NBL_CONST_REF_ARG(Comparator) comp)
4244
{
43-
const bool shouldSwap = comp(data[1], data[0]);
44-
const bool doSwap = (shouldSwap == ascending);
45-
46-
sortable_t temp = data[0];
47-
data[0] = doSwap ? data[1] : data[0];
48-
data[1] = doSwap ? temp : data[1];
45+
// For ascending: swap if data[1] < data[0] (put smaller first)
46+
// For descending: swap if data[0] < data[1] (put larger first)
47+
const bool needSwap = ascending ? comp(data[1], data[0]) : comp(data[0], data[1]);
48+
49+
if (needSwap)
50+
{
51+
sortable_t temp = data[0];
52+
data[0] = data[1];
53+
data[1] = temp;
54+
}
4955
}
5056
};
5157

Lines changed: 73 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,94 @@
1-
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
2-
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
1+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
3+
34
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
4-
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
5+
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
56
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
6-
#include "nbl/builtin/hlsl/functional.hlsl"
7+
78
namespace nbl
89
{
910
namespace hlsl
1011
{
1112
namespace subgroup
1213
{
14+
namespace bitonic_sort
15+
{
16+
using namespace nbl::hlsl::bitonic_sort;
1317

14-
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
15-
struct bitonic_sort_config
18+
template<typename KeyType, typename Comparator, class device_capabilities = void>
19+
struct bitonic_sort_wgtype
1620
{
17-
using key_t = KeyType;
18-
using value_t = ValueType;
19-
using comparator_t = Comparator;
20-
};
21+
using WGType = WorkgroupType<KeyType>;
22+
using key_t = KeyType;
23+
using comparator_t = Comparator;
2124

22-
template<typename Config, class device_capabilities = void>
23-
struct bitonic_sort;
25+
static void mergeStage(
26+
uint32_t stage,
27+
bool bitonicAscending,
28+
uint32_t invocationID,
29+
NBL_REF_ARG(WGType) lo,
30+
NBL_REF_ARG(WGType) hi)
31+
{
32+
comparator_t comp;
2433

25-
template<typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
26-
struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
27-
{
28-
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
29-
using key_t = typename config_t::key_t;
30-
using value_t = typename config_t::value_t;
31-
using comparator_t = typename config_t::comparator_t;
34+
[unroll]
35+
for (uint32_t pass = 0u; pass <= stage; ++pass)
36+
{
37+
uint32_t stride = 1u << (stage - pass);
38+
uint32_t partner = stride >> 1;
39+
40+
if (partner == 0u)
41+
{
42+
bool swap = comp(hi.key, lo.key) == bitonicAscending;
43+
WGType tmp = lo;
44+
lo.key = swap ? hi.key : lo.key;
45+
lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
46+
hi.key = swap ? tmp.key : hi.key;
47+
hi.workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : hi.workgroupRelativeIndex;
48+
}
49+
else
50+
{
51+
bool isUpper = (invocationID & partner) != 0u;
3252

33-
static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID,
34-
NBL_REF_ARG(pair<key_t, value_t>) loPair, NBL_REF_ARG(pair<key_t, value_t>) hiPair)
35-
{
36-
comparator_t comp;
53+
// Select which element to trade and shuffle members individually
54+
key_t tradingKey = isUpper ? hi.key : lo.key;
55+
uint32_t tradingIdx = isUpper ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
3756

38-
[unroll]
39-
for (uint32_t pass = 0; pass <= stage; pass++)
40-
{
41-
const uint32_t stride = 1u << (stage - pass); // Element stride
42-
const uint32_t threadStride = stride >> 1;
43-
if (threadStride == 0)
44-
{
45-
// Local compare and swap for stage 0
46-
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loPair, hiPair, comp);
47-
}
48-
else
49-
{
50-
// Shuffle from partner using XOR
51-
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loPair.first, threadStride);
52-
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loPair.second, threadStride);
53-
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiPair.first, threadStride);
54-
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiPair.second, threadStride);
57+
tradingKey = glsl::subgroupShuffleXor(tradingKey, partner);
58+
tradingIdx = glsl::subgroupShuffleXor(tradingIdx, partner);
5559

56-
const pair<key_t, value_t> partnerLoPair = make_pair(pLoKey, pLoVal);
57-
const pair<key_t, value_t> partnerHiPair = make_pair(pHiKey, pHiVal);
60+
lo.key = isUpper ? lo.key : tradingKey;
61+
lo.workgroupRelativeIndex = isUpper ? lo.workgroupRelativeIndex : tradingIdx;
62+
hi.key = isUpper ? tradingKey : hi.key;
63+
hi.workgroupRelativeIndex = isUpper ? tradingIdx : hi.workgroupRelativeIndex;
5864

59-
const bool isUpper = bool(invocationID & threadStride);
60-
const bool takeLarger = isUpper == bitonicAscending;
65+
bool swap = comp(hi.key, lo.key) == bitonicAscending;
66+
WGType tmp = lo;
67+
lo.key = swap ? hi.key : lo.key;
68+
lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
69+
hi.key = swap ? tmp.key : hi.key;
70+
hi.workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : hi.workgroupRelativeIndex;
71+
}
72+
}
73+
}
6174

62-
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, partnerLoPair, hiPair, partnerHiPair, comp);
63-
}
64-
}
65-
}
75+
static void __call(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
76+
{
77+
uint32_t id = glsl::gl_SubgroupInvocationID();
78+
uint32_t log2 = glsl::gl_SubgroupSizeLog2();
6679

67-
static void __call(bool ascending, NBL_REF_ARG(pair<key_t, value_t>) loPair, NBL_REF_ARG(pair<key_t, value_t>) hiPair)
68-
{
69-
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
70-
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
71-
[unroll]
72-
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++)
73-
{
74-
const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool(invocationID & (1u << stage));
75-
mergeStage(stage, bitonicAscending, invocationID, loPair, hiPair);
76-
}
77-
}
80+
[unroll]
81+
for (uint32_t s = 0u; s <= log2; ++s)
82+
{
83+
bool dir = (s == log2) ? ascending : ((id & (1u << s)) != 0u);
84+
mergeStage(s, dir, id, lo, hi);
85+
}
86+
}
7887
};
7988

80-
}
81-
}
82-
}
89+
} // namespace bitonic_sort
90+
} // namespace subgroup
91+
} // namespace hlsl
92+
} // namespace nbl
93+
8394
#endif

0 commit comments

Comments
 (0)