Skip to content

Commit a5e811b

Browse files
committed
some changes for debugging
1 parent 589056f commit a5e811b

1 file changed

Lines changed: 40 additions & 39 deletions

File tree

include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl

Lines changed: 40 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#define _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_
33

44
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
5+
#include "nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl"
56
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
67
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
78

@@ -78,14 +79,14 @@ inline void inThreadShuffleSortPairs(
7879
WorkgroupType<typename config::key_t> elems[config::E],
7980
NBL_CONST_REF_ARG(Comp) comp)
8081
{
81-
uint32_t partner = stride;
82+
uint32_t partner = stride >> 1;
8283

8384
[unroll]
84-
for (uint32_t i = 0u; i < config::E; i += stride * 2u)
85+
for (uint32_t i = 0u; i < config::E; i += stride)
8586
{
8687
uint32_t j = i + partner;
8788
bool valid = j < config::E;
88-
bool swap = valid && (ascending ? comp(elems[j], elems[i]) : comp(elems[i], elems[j]));
89+
bool swap = valid && (comp(elems[j].key, elems[i].key) == ascending);
8990

9091
WorkgroupType<typename config::key_t> tmp = elems[i];
9192

@@ -104,26 +105,33 @@ inline void subgroupShuffleSortPairs(
104105
uint32_t subgroupInvocationID,
105106
NBL_CONST_REF_ARG(Comp) comp)
106107
{
107-
bool isUpper = (subgroupInvocationID & stride) != 0u;
108+
uint32_t partner = stride >> 1;
109+
bool isUpper = (subgroupInvocationID & partner) != 0u;
108110

109111
[unroll]
110-
for (uint32_t i = 0u; i < config::E; ++i)
112+
for (uint32_t i = 0u; i < config::E; i += 2u)
111113
{
112-
typename config::key_t partnerKey = glsl::subgroupShuffleXor(elems[i].key, stride);
113-
uint32_t partnerIdx = glsl::subgroupShuffleXor(elems[i].workgroupRelativeIndex, stride);
114+
uint32_t j = i + 1u;
114115

115-
WorkgroupType<typename config::key_t> partnerElem;
116-
partnerElem.key = partnerKey;
117-
partnerElem.workgroupRelativeIndex = partnerIdx;
116+
typename config::key_t tradingKey = isUpper ? elems[j].key : elems[i].key;
117+
uint32_t tradingIdx = isUpper ? elems[j].workgroupRelativeIndex : elems[i].workgroupRelativeIndex;
118118

119-
120-
bool keepPartner = (ascending == isUpper) ? comp(partnerElem, elems[i]) : comp(elems[i], partnerElem);
121-
if (keepPartner) {
122-
elems[i] = partnerElem;
123-
}
124-
}
119+
tradingKey = glsl::subgroupShuffleXor(tradingKey, partner);
120+
tradingIdx = glsl::subgroupShuffleXor(tradingIdx, partner);
125121

126-
atomicOpPairs<config>(ascending, elems, comp);
122+
elems[i].key = isUpper ? elems[i].key : tradingKey;
123+
elems[i].workgroupRelativeIndex = isUpper ? elems[i].workgroupRelativeIndex : tradingIdx;
124+
elems[j].key = isUpper ? tradingKey : elems[j].key;
125+
elems[j].workgroupRelativeIndex = isUpper ? tradingIdx : elems[j].workgroupRelativeIndex;
126+
127+
bool swap = comp(elems[j].key, elems[i].key) == ascending;
128+
WorkgroupType<typename config::key_t> tmp = elems[i];
129+
130+
elems[i].key = swap ? elems[j].key : elems[i].key;
131+
elems[i].workgroupRelativeIndex = swap ? elems[j].workgroupRelativeIndex : elems[i].workgroupRelativeIndex;
132+
elems[j].key = swap ? tmp.key : elems[j].key;
133+
elems[j].workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : elems[j].workgroupRelativeIndex;
134+
}
127135
}
128136

129137
template<typename config, typename Comp, typename SMem>
@@ -155,36 +163,30 @@ inline void workgroupShuffleSortPairs(
155163
bool active = (tid >= start) && (tid < end);
156164

157165
[unroll]
158-
for (uint32_t i = 0u; i < config::E; ++i)
166+
for (uint32_t i = 0u; i < config::E/2; ++i)
159167
{
160-
if (active)
168+
uint32_t localIdx = isUpper ? (i + config::E/2) : i;
169+
if (active && (localIdx < config::E))
161170
{
162-
k.set(tid*config::E + i, elems[i].key);
163-
idx.set(tid*config::E + i, elems[i].workgroupRelativeIndex);
171+
k.set(tid*config::E + localIdx, elems[localIdx].key);
172+
idx.set(tid*config::E + localIdx, elems[localIdx].workgroupRelativeIndex);
164173
}
165174
}
166175
smem.workgroupExecutionAndMemoryBarrier();
167176

168177
[unroll]
169-
for (uint32_t i = 0u; i < config::E; ++i)
178+
for (uint32_t i = 0u; i < config::E/2; ++i)
170179
{
171-
if (active)
180+
uint32_t localIdx = isUpper ? i : (i + config::E/2);
181+
if (active && (localIdx < config::E))
172182
{
173-
uint32_t myElemIdx = tid * config::E + i;
183+
uint32_t myElemIdx = tid * config::E + (isUpper ? (i + config::E/2) : i);
174184
uint32_t partnerElemIdx = myElemIdx ^ stride;
175185
uint32_t pTid = partnerElemIdx / config::E;
176186
uint32_t partnerLocalIdx = partnerElemIdx % config::E;
177187

178-
WorkgroupType<typename config::key_t> partnerElem;
179-
k.get(pTid*config::E + partnerLocalIdx, partnerElem.key);
180-
idx.get(pTid*config::E + partnerLocalIdx, partnerElem.workgroupRelativeIndex);
181-
182-
bool isUpper = (myElemIdx & stride) != 0u;
183-
184-
bool keepPartner = (ascending == isUpper) ? comp(partnerElem, elems[i]) : comp(elems[i], partnerElem);
185-
if (keepPartner) {
186-
elems[i] = partnerElem;
187-
}
188+
k.get(pTid*config::E + partnerLocalIdx, elems[localIdx].key);
189+
idx.get(pTid*config::E + partnerLocalIdx, elems[localIdx].workgroupRelativeIndex);
188190
}
189191
}
190192
smem.workgroupExecutionAndMemoryBarrier();
@@ -246,8 +248,8 @@ struct BitonicSort
246248
uint32_t idx = tid * config::E + i;
247249
KVPair kvpair;
248250
acc.template get<KVPair>(idx, kvpair);
249-
elems[i].key = kvpair.first; // The key to sort by (random number)
250-
elems[i].workgroupRelativeIndex = kvpair.second; // The original index to track
251+
elems[i].key = kvpair.first;
252+
elems[i].workgroupRelativeIndex = kvpair.second;
251253
}
252254

253255
typename config::comparator_t comp;
@@ -268,13 +270,12 @@ struct BitonicSort
268270
}
269271
}
270272

271-
// Write back sorted (sortedKey, originalIndex) pairs
272273
[unroll]
273274
for (uint32_t i = 0u; i < config::E; ++i)
274275
{
275276
KVPair output;
276-
output.first = elems[i].key; // Sorted key
277-
output.second = elems[i].workgroupRelativeIndex; // Original index
277+
output.first = elems[i].key;
278+
output.second = elems[i].workgroupRelativeIndex;
278279
acc.template set<KVPair>(tid * config::E + i, output);
279280
}
280281
}

0 commit comments

Comments
 (0)