|
1 | | -#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED |
2 | | -#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED |
| 1 | +#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_ |
| 2 | +#define _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_ |
| 3 | + |
3 | 4 | #include "nbl/builtin/hlsl/bitonic_sort/common.hlsl" |
4 | | -#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" |
| 5 | +#include "nbl/builtin/hlsl/subgroup/basic.hlsl" |
5 | 6 | #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" |
6 | | -#include "nbl/builtin/hlsl/functional.hlsl" |
| 7 | + |
7 | 8 | namespace nbl |
8 | 9 | { |
9 | 10 | namespace hlsl |
10 | 11 | { |
11 | 12 | namespace subgroup |
12 | 13 | { |
| 14 | +namespace bitonic_sort |
| 15 | +{ |
| 16 | +using namespace nbl::hlsl::bitonic_sort; |
13 | 17 |
|
14 | | -template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> > |
15 | | -struct bitonic_sort_config |
/**
 * Subgroup-wide bitonic sort where every invocation owns two elements (`lo`, `hi`).
 *
 * Each element is a `WorkgroupType<KeyType>`; this code only touches its `key`
 * and `workgroupRelativeIndex` members (the type itself is presumably declared in
 * bitonic_sort/common.hlsl -- confirm). The index travels with its key so callers
 * can recover where a sorted key originally came from.
 *
 * @tparam KeyType             type of the sort key
 * @tparam Comparator          default-constructible functor; `comp(a, b)` is used as a
 *                             strict "a orders before b" predicate on keys
 * @tparam device_capabilities unused by the code in this struct
 */
template<typename KeyType, typename Comparator, class device_capabilities = void>
struct bitonic_sort_wgtype
{
    using WGType = WorkgroupType<KeyType>;
    using key_t = KeyType;
    using comparator_t = Comparator;

    /**
     * Orders the two elements held by this invocation.
     * A swap happens when `comp(hi.key, lo.key) == ascending`; key and index move together.
     */
    static void compareSwapLocal(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
    {
        comparator_t comp;
        const bool swap = comp(hi.key, lo.key) == ascending;
        const WGType oldLo = lo;
        lo.key = swap ? hi.key : oldLo.key;
        lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : oldLo.workgroupRelativeIndex;
        hi.key = swap ? oldLo.key : hi.key;
        hi.workgroupRelativeIndex = swap ? oldLo.workgroupRelativeIndex : hi.workgroupRelativeIndex;
    }

    /**
     * Compare-exchange of one slot with the same slot of the lane `id ^ partnerMask`.
     *
     * Both lanes read each other's element and each keeps exactly one of the two:
     * the larger one when `takeLarger` is set, the smaller one otherwise. The partner's
     * element is taken only when it is strictly better, so on equal keys both lanes keep
     * their own element and nothing is duplicated or lost.
     * Must be called by both lanes of the pair (the shuffles need the partner to participate).
     */
    static void compareExchangeWithPartner(bool takeLarger, uint32_t partnerMask, NBL_REF_ARG(WGType) elem)
    {
        comparator_t comp;
        const key_t partnerKey = glsl::subgroupShuffleXor<key_t>(elem.key, partnerMask);
        const uint32_t partnerIdx = glsl::subgroupShuffleXor<uint32_t>(elem.workgroupRelativeIndex, partnerMask);

        const bool takePartner = takeLarger ? comp(elem.key, partnerKey) : comp(partnerKey, elem.key);
        elem.key = takePartner ? partnerKey : elem.key;
        elem.workgroupRelativeIndex = takePartner ? partnerIdx : elem.workgroupRelativeIndex;
    }

    /**
     * Runs one bitonic merge stage: passes 0..stage with element stride `1 << (stage - pass)`.
     *
     * @param stage            merge stage index; the stage performs `stage + 1` passes
     * @param bitonicAscending direction this invocation sorts towards during the stage
     * @param invocationID     subgroup invocation index of the caller
     * @param lo,hi            the two elements owned by this invocation, updated in place
     */
    static void mergeStage(
        uint32_t stage,
        bool bitonicAscending,
        uint32_t invocationID,
        NBL_REF_ARG(WGType) lo,
        NBL_REF_ARG(WGType) hi)
    {
        [unroll]
        for (uint32_t pass = 0u; pass <= stage; ++pass)
        {
            const uint32_t stride = 1u << (stage - pass);
            // Lane distance to the partner; zero means both elements live in this invocation.
            const uint32_t partner = stride >> 1;

            if (partner == 0u)
            {
                compareSwapLocal(bitonicAscending, lo, hi);
            }
            else
            {
                // The lane with the `partner` bit set is the upper half of the pair. In an
                // ascending merge the upper lane keeps the larger element of every pair and
                // the lower lane the smaller one; a descending merge flips that.
                // Exchanging only one element per lane and compare-swapping locally (without
                // trading back) leaves min/max in the wrong slots, so both slots are
                // compare-exchanged against the partner's matching slot instead.
                const bool isUpper = (invocationID & partner) != 0u;
                const bool takeLarger = isUpper == bitonicAscending;

                compareExchangeWithPartner(takeLarger, partner, lo);
                compareExchangeWithPartner(takeLarger, partner, hi);
            }
        }
    }

    /**
     * Sorts the `2 * subgroupSize` elements held by the subgroup (two per invocation).
     *
     * @param ascending direction of the final merge stage, i.e. of the finished sequence
     * @param lo,hi     this invocation's two elements, updated in place
     */
    static void __call(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
    {
        const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
        // Not named `log2`, to avoid shadowing the HLSL intrinsic of that name.
        const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();

        [unroll]
        for (uint32_t stage = 0u; stage <= subgroupSizeLog2; ++stage)
        {
            // Before the last stage, bit `stage` of the invocation index picks the direction,
            // so neighbouring runs are sorted in opposite directions and together form a
            // bitonic sequence for the next stage. The last stage uses the requested direction.
            const bool stageAscending = (stage == subgroupSizeLog2) ? ascending : ((invocationID & (1u << stage)) != 0u);
            mergeStage(stage, stageAscending, invocationID, lo, hi);
        }
    }
};
79 | 88 |
|
80 | | -} |
81 | | -} |
82 | | -} |
| 89 | +} // namespace bitonic_sort |
| 90 | +} // namespace subgroup |
| 91 | +} // namespace hlsl |
| 92 | +} // namespace nbl |
| 93 | + |
83 | 94 | #endif |
0 commit comments