-
Notifications
You must be signed in to change notification settings - Fork 71
Expand file tree
/
Copy pathbitonic_sort.hlsl
More file actions
94 lines (80 loc) · 3.17 KB
/
bitonic_sort.hlsl
File metadata and controls
94 lines (80 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{
namespace bitonic_sort
{
// Makes the names of nbl::hlsl::bitonic_sort visible unqualified inside this
// namespace. NOTE(review): `WorkgroupType`, used unqualified below, is presumably
// one of them (declared in bitonic_sort/common.hlsl) - confirm.
using namespace nbl::hlsl::bitonic_sort;
// Subgroup-level bitonic merge network in which every invocation holds two
// elements, `lo` and `hi`. Each element is a `WorkgroupType<KeyType>`; this code
// reads and writes only its `key` and `workgroupRelativeIndex` members.
//
// Template parameters:
//   KeyType             - type of the `key` member that is compared.
//   Comparator          - default-constructible functor invoked as
//                         `comp(a.key, b.key)`; its result is compared against a bool.
//   device_capabilities - not referenced anywhere inside this struct.
template<typename KeyType, typename Comparator, class device_capabilities = void>
struct bitonic_sort_wgtype
{
    using WGType = WorkgroupType<KeyType>;
    using key_t = KeyType;
    using comparator_t = Comparator;

    // Runs the `stage + 1` passes of one merge stage on this invocation's pair.
    //
    //   stage            - stage index; pass `p` uses an element stride of
    //                      `1 << (stage - p)`.
    //   bitonicAscending - direction flag: the pair is exchanged whenever
    //                      `comp(hi.key, lo.key)` equals this flag.
    //   invocationID     - this invocation's index within the subgroup.
    //   lo, hi           - the two elements held by this invocation, updated in place.
    //
    // The shuffle below is only skipped when `partner` is zero, and `partner`
    // depends solely on `stage` and the pass counter, never on `invocationID`.
    static void mergeStage(
        uint32_t stage,
        bool bitonicAscending,
        uint32_t invocationID,
        NBL_REF_ARG(WGType) lo,
        NBL_REF_ARG(WGType) hi)
    {
        comparator_t comp;

        // Counter deliberately not called `pass`: that identifier is on the HLSL keyword list.
        [unroll]
        for (uint32_t passIx = 0u; passIx <= stage; ++passIx)
        {
            const uint32_t stride = 1u << (stage - passIx);
            // Distance (in invocations) to the trading partner; zero on the last
            // pass, where only the local lo/hi pair is compared.
            const uint32_t partner = stride >> 1;

            if (partner != 0u)
            {
                // Invocations with the `partner` bit set give away `hi` and keep `lo`;
                // the others give away `lo` and keep `hi`. Whatever comes back from the
                // XOR shuffle lands in the slot that was given away.
                // NOTE(review): assumes subgroupShuffleXor(v, m) returns `v` from
                // invocation `id ^ m` - confirm against glsl_compat/subgroup_shuffle.hlsl.
                const bool isUpper = (invocationID & partner) != 0u;

                // Members are shuffled one by one rather than as a whole struct.
                key_t tradingKey = isUpper ? hi.key : lo.key;
                uint32_t tradingIdx = isUpper ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
                tradingKey = glsl::subgroupShuffleXor(tradingKey, partner);
                tradingIdx = glsl::subgroupShuffleXor(tradingIdx, partner);

                lo.key = isUpper ? lo.key : tradingKey;
                lo.workgroupRelativeIndex = isUpper ? lo.workgroupRelativeIndex : tradingIdx;
                hi.key = isUpper ? tradingKey : hi.key;
                hi.workgroupRelativeIndex = isUpper ? tradingIdx : hi.workgroupRelativeIndex;
            }

            // Compare-exchange of the local pair, done on every pass (with or without
            // a preceding trade). Written with selects, so both members are always
            // assigned and only the chosen source differs.
            const bool doSwap = comp(hi.key, lo.key) == bitonicAscending;
            WGType prevLo = lo;
            lo.key = doSwap ? hi.key : lo.key;
            lo.workgroupRelativeIndex = doSwap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
            hi.key = doSwap ? prevLo.key : hi.key;
            hi.workgroupRelativeIndex = doSwap ? prevLo.workgroupRelativeIndex : hi.workgroupRelativeIndex;
        }
    }

    // Entry point: runs merge stages `0 .. gl_SubgroupSizeLog2()` (inclusive) on the
    // calling invocation's pair.
    //
    //   ascending - direction flag used for the final stage only.
    //   lo, hi    - the two elements held by this invocation, updated in place.
    //
    // For every stage `s` before the last, the direction flag passed to mergeStage
    // is bit `s` of the invocation ID, so it differs between invocations.
    static void __call(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
    {
        const uint32_t id = glsl::gl_SubgroupInvocationID();
        // Not named `log2`, to avoid shadowing the HLSL intrinsic of the same name.
        const uint32_t sizeLog2 = glsl::gl_SubgroupSizeLog2();

        [unroll]
        for (uint32_t s = 0u; s <= sizeLog2; ++s)
        {
            const bool dir = (s == sizeLog2) ? ascending : ((id & (1u << s)) != 0u);
            mergeStage(s, dir, id, lo, hi);
        }
    }
};
} // namespace bitonic_sort
} // namespace subgroup
} // namespace hlsl
} // namespace nbl
#endif