22#define _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_
33
44#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
5+ #include "nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl"
56#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
67#include "nbl/builtin/hlsl/memory_accessor.hlsl"
78
@@ -78,14 +79,14 @@ inline void inThreadShuffleSortPairs(
7879 WorkgroupType<typename config::key_t> elems[config::E],
7980 NBL_CONST_REF_ARG (Comp) comp)
8081{
81- uint32_t partner = stride;
82+ uint32_t partner = stride >> 1 ;
8283
8384 [unroll]
84- for (uint32_t i = 0u; i < config::E; i += stride * 2u )
85+ for (uint32_t i = 0u; i < config::E; i += stride)
8586 {
8687 uint32_t j = i + partner;
8788 bool valid = j < config::E;
88- bool swap = valid && (ascending ? comp (elems[j], elems[i]) : comp (elems[i], elems[j]) );
89+ bool swap = valid && (comp (elems[j].key , elems[i].key) == ascending );
8990
9091 WorkgroupType<typename config::key_t> tmp = elems[i];
9192
@@ -104,26 +105,33 @@ inline void subgroupShuffleSortPairs(
104105 uint32_t subgroupInvocationID,
105106 NBL_CONST_REF_ARG (Comp) comp)
106107{
107- bool isUpper = (subgroupInvocationID & stride) != 0u;
108+ uint32_t partner = stride >> 1 ;
109+ bool isUpper = (subgroupInvocationID & partner) != 0u;
108110
109111 [unroll]
110- for (uint32_t i = 0u; i < config::E; ++i )
112+ for (uint32_t i = 0u; i < config::E; i += 2u )
111113 {
112- typename config::key_t partnerKey = glsl::subgroupShuffleXor (elems[i].key, stride);
113- uint32_t partnerIdx = glsl::subgroupShuffleXor (elems[i].workgroupRelativeIndex, stride);
114+ uint32_t j = i + 1u;
114115
115- WorkgroupType<typename config::key_t> partnerElem;
116- partnerElem.key = partnerKey;
117- partnerElem.workgroupRelativeIndex = partnerIdx;
116+ typename config::key_t tradingKey = isUpper ? elems[j].key : elems[i].key;
117+ uint32_t tradingIdx = isUpper ? elems[j].workgroupRelativeIndex : elems[i].workgroupRelativeIndex;
118118
119-
120- bool keepPartner = (ascending == isUpper) ? comp (partnerElem, elems[i]) : comp (elems[i], partnerElem);
121- if (keepPartner) {
122- elems[i] = partnerElem;
123- }
124- }
119+ tradingKey = glsl::subgroupShuffleXor (tradingKey, partner);
120+ tradingIdx = glsl::subgroupShuffleXor (tradingIdx, partner);
125121
126- atomicOpPairs<config>(ascending, elems, comp);
122+ elems[i].key = isUpper ? elems[i].key : tradingKey;
123+ elems[i].workgroupRelativeIndex = isUpper ? elems[i].workgroupRelativeIndex : tradingIdx;
124+ elems[j].key = isUpper ? tradingKey : elems[j].key;
125+ elems[j].workgroupRelativeIndex = isUpper ? tradingIdx : elems[j].workgroupRelativeIndex;
126+
127+ bool swap = comp (elems[j].key, elems[i].key) == ascending;
128+ WorkgroupType<typename config::key_t> tmp = elems[i];
129+
130+ elems[i].key = swap ? elems[j].key : elems[i].key;
131+ elems[i].workgroupRelativeIndex = swap ? elems[j].workgroupRelativeIndex : elems[i].workgroupRelativeIndex;
132+ elems[j].key = swap ? tmp.key : elems[j].key;
133+ elems[j].workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : elems[j].workgroupRelativeIndex;
134+ }
127135}
128136
129137template<typename config, typename Comp, typename SMem>
@@ -155,36 +163,30 @@ inline void workgroupShuffleSortPairs(
155163 bool active = (tid >= start) && (tid < end);
156164
157165 [unroll]
158- for (uint32_t i = 0u; i < config::E; ++i)
166+ for (uint32_t i = 0u; i < config::E/ 2 ; ++i)
159167 {
160- if (active)
168+ uint32_t localIdx = isUpper ? (i + config::E/2 ) : i;
169+ if (active && (localIdx < config::E))
161170 {
162- k.set (tid*config::E + i , elems[i ].key);
163- idx.set (tid*config::E + i , elems[i ].workgroupRelativeIndex);
171+ k.set (tid*config::E + localIdx , elems[localIdx ].key);
172+ idx.set (tid*config::E + localIdx , elems[localIdx ].workgroupRelativeIndex);
164173 }
165174 }
166175 smem.workgroupExecutionAndMemoryBarrier ();
167176
168177 [unroll]
169- for (uint32_t i = 0u; i < config::E; ++i)
178+ for (uint32_t i = 0u; i < config::E/ 2 ; ++i)
170179 {
171- if (active)
180+ uint32_t localIdx = isUpper ? i : (i + config::E/2 );
181+ if (active && (localIdx < config::E))
172182 {
173- uint32_t myElemIdx = tid * config::E + i ;
183+ uint32_t myElemIdx = tid * config::E + (isUpper ? (i + config::E/ 2 ) : i) ;
174184 uint32_t partnerElemIdx = myElemIdx ^ stride;
175185 uint32_t pTid = partnerElemIdx / config::E;
176186 uint32_t partnerLocalIdx = partnerElemIdx % config::E;
177187
178- WorkgroupType<typename config::key_t> partnerElem;
179- k.get (pTid*config::E + partnerLocalIdx, partnerElem.key);
180- idx.get (pTid*config::E + partnerLocalIdx, partnerElem.workgroupRelativeIndex);
181-
182- bool isUpper = (myElemIdx & stride) != 0u;
183-
184- bool keepPartner = (ascending == isUpper) ? comp (partnerElem, elems[i]) : comp (elems[i], partnerElem);
185- if (keepPartner) {
186- elems[i] = partnerElem;
187- }
188+ k.get (pTid*config::E + partnerLocalIdx, elems[localIdx].key);
189+ idx.get (pTid*config::E + partnerLocalIdx, elems[localIdx].workgroupRelativeIndex);
188190 }
189191 }
190192 smem.workgroupExecutionAndMemoryBarrier ();
@@ -246,8 +248,8 @@ struct BitonicSort
246248 uint32_t idx = tid * config::E + i;
247249 KVPair kvpair;
248250 acc.template get<KVPair>(idx, kvpair);
249- elems[i].key = kvpair.first; // The key to sort by (random number)
250- elems[i].workgroupRelativeIndex = kvpair.second; // The original index to track
251+ elems[i].key = kvpair.first;
252+ elems[i].workgroupRelativeIndex = kvpair.second;
251253 }
252254
253255 typename config::comparator_t comp;
@@ -268,13 +270,12 @@ struct BitonicSort
268270 }
269271 }
270272
271- // Write back sorted (sortedKey, originalIndex) pairs
272273 [unroll]
273274 for (uint32_t i = 0u; i < config::E; ++i)
274275 {
275276 KVPair output;
276- output.first = elems[i].key; // Sorted key
277- output.second = elems[i].workgroupRelativeIndex; // Original index
277+ output.first = elems[i].key;
278+ output.second = elems[i].workgroupRelativeIndex;
278279 acc.template set<KVPair>(tid * config::E + i, output);
279280 }
280281 }
0 commit comments