|
| 1 | +#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_ |
| 2 | +#define _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_ |
| 3 | + |
| 4 | +#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl" |
| 5 | +#include "nbl/builtin/hlsl/subgroup/basic.hlsl" |
| 6 | +#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" |
| 7 | + |
| 8 | +namespace nbl |
| 9 | +{ |
| 10 | +namespace hlsl |
| 11 | +{ |
| 12 | +namespace subgroup |
| 13 | +{ |
| 14 | +namespace bitonic_sort |
| 15 | +{ |
| 16 | +using namespace nbl::hlsl::bitonic_sort; |
| 17 | + |
| 18 | +template<typename KeyType, typename Comparator, class device_capabilities = void> |
| 19 | +struct bitonic_sort_wgtype |
| 20 | +{ |
| 21 | + using WGType = WorkgroupType<KeyType>; |
| 22 | + using key_t = KeyType; |
| 23 | + using comparator_t = Comparator; |
| 24 | + |
| 25 | + static void mergeStage( |
| 26 | + uint32_t stage, |
| 27 | + bool bitonicAscending, |
| 28 | + uint32_t invocationID, |
| 29 | + NBL_REF_ARG(WGType) lo, |
| 30 | + NBL_REF_ARG(WGType) hi) |
| 31 | + { |
| 32 | + comparator_t comp; |
| 33 | + |
| 34 | + [unroll] |
| 35 | + for (uint32_t pass = 0u; pass <= stage; ++pass) |
| 36 | + { |
| 37 | + uint32_t stride = 1u << (stage - pass); |
| 38 | + uint32_t partner = stride >> 1; |
| 39 | + |
| 40 | + if (partner == 0u) |
| 41 | + { |
| 42 | + bool swap = comp(hi.key, lo.key) == bitonicAscending; |
| 43 | + WGType tmp = lo; |
| 44 | + lo.key = swap ? hi.key : lo.key; |
| 45 | + lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex; |
| 46 | + hi.key = swap ? tmp.key : hi.key; |
| 47 | + hi.workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : hi.workgroupRelativeIndex; |
| 48 | + } |
| 49 | + else |
| 50 | + { |
| 51 | + bool isUpper = (invocationID & partner) != 0u; |
| 52 | + |
| 53 | + // Select which element to trade and shuffle members individually |
| 54 | + key_t tradingKey = isUpper ? hi.key : lo.key; |
| 55 | + uint32_t tradingIdx = isUpper ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex; |
| 56 | + |
| 57 | + tradingKey = glsl::subgroupShuffleXor(tradingKey, partner); |
| 58 | + tradingIdx = glsl::subgroupShuffleXor(tradingIdx, partner); |
| 59 | + |
| 60 | + lo.key = isUpper ? lo.key : tradingKey; |
| 61 | + lo.workgroupRelativeIndex = isUpper ? lo.workgroupRelativeIndex : tradingIdx; |
| 62 | + hi.key = isUpper ? tradingKey : hi.key; |
| 63 | + hi.workgroupRelativeIndex = isUpper ? tradingIdx : hi.workgroupRelativeIndex; |
| 64 | + |
| 65 | + bool swap = comp(hi.key, lo.key) == bitonicAscending; |
| 66 | + WGType tmp = lo; |
| 67 | + lo.key = swap ? hi.key : lo.key; |
| 68 | + lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex; |
| 69 | + hi.key = swap ? tmp.key : hi.key; |
| 70 | + hi.workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : hi.workgroupRelativeIndex; |
| 71 | + } |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + static void __call(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi) |
| 76 | + { |
| 77 | + uint32_t id = glsl::gl_SubgroupInvocationID(); |
| 78 | + uint32_t log2 = glsl::gl_SubgroupSizeLog2(); |
| 79 | + |
| 80 | + [unroll] |
| 81 | + for (uint32_t s = 0u; s <= log2; ++s) |
| 82 | + { |
| 83 | + bool dir = (s == log2) ? ascending : ((id & (1u << s)) != 0u); |
| 84 | + mergeStage(s, dir, id, lo, hi); |
| 85 | + } |
| 86 | + } |
| 87 | +}; |
| 88 | + |
| 89 | +} // namespace bitonic_sort |
| 90 | +} // namespace subgroup |
| 91 | +} // namespace hlsl |
| 92 | +} // namespace nbl |
| 93 | + |
| 94 | +#endif |
0 commit comments