Skip to content

Commit c12a724

Browse files
committed
forgotten to commit file
1 parent 37cb817 commit c12a724

1 file changed

Lines changed: 143 additions & 46 deletions

File tree

include/nbl/builtin/hlsl/sampling/cumulative_probability.hlsl

Lines changed: 143 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -16,18 +16,38 @@ namespace hlsl
1616
namespace sampling
1717
{
1818

19-
// Discrete sampler using cumulative probability lookup via upper_bound.
19+
// Discrete sampler using cumulative probability lookup.
2020
//
2121
// Samples a discrete index in [0, N) with probability proportional to
2222
// precomputed weights in O(log N) time per sample.
2323
//
24-
// The cumulative probability array stores N-1 entries (the last bucket
25-
// is always 1.0 and need not be stored). Entry i holds the sum of
26-
// probabilities for indices [0, i].
24+
// Three layouts / cache-population strategies, selected by the Mode parameter:
25+
//
26+
// TRACKING (default): N-1 CDF entries, last bucket implicit at 1.0.
27+
// A stateful comparator records the straddling CDF
28+
// values during upper_bound itself.
29+
// YOLO: Same storage. Plain upper_bound followed by two
30+
// re-reads of the adjacent CDF entries (warm cache).
31+
// Lower register footprint, two extra array reads.
32+
// EYTZINGER: Level-order implicit binary tree in 2*P entries
33+
// where P = roundUpPot(N). Leaves at [P, P+N) hold
34+
// the CDF; interior nodes at [1, P) hold split keys.
35+
// Descent reads adjacent memory at each step, so
36+
// every cache line pulled is fully utilised and the
37+
// first log2(subgroupSize) iterations are served by a
38+
// single transaction per subgroup. Build with
39+
// sampling::buildEytzinger<T>().
2740
//
2841
// Satisfies TractableSampler and ResamplableSampler (not BackwardTractableSampler:
2942
// the mapping is discrete).
30-
template<typename T, typename Domain, typename Codomain, typename CumProbAccessor
43+
enum CumulativeProbabilityMode : uint32_t
44+
{
45+
TRACKING = 0u,
46+
YOLO = 1u,
47+
EYTZINGER = 2u
48+
};
49+
50+
template<typename T, typename Domain, typename Codomain, typename CumProbAccessor, CumulativeProbabilityMode Mode = CumulativeProbabilityMode::TRACKING
3151
NBL_PRIMARY_REQUIRES(concepts::accessors::GenericReadAccessor<CumProbAccessor, T, Codomain>)
3252
struct CumulativeProbabilitySampler
3353
{
@@ -44,58 +64,116 @@ struct CumulativeProbabilitySampler
4464
density_type upperBound;
4565
};
4666

67+
// `_size` is the user-facing bucket count N for every mode. TRACKING / YOLO
68+
// expect the accessor to hold N-1 CDF entries; EYTZINGER expects 2*P entries
69+
// in the level-order layout produced by buildEytzinger.
4770
static CumulativeProbabilitySampler create(NBL_CONST_REF_ARG(CumProbAccessor) _cumProbAccessor, uint32_t _size)
4871
{
4972
CumulativeProbabilitySampler retval;
5073
retval.cumProbAccessor = _cumProbAccessor;
5174
retval.storedCount = _size - 1u;
75+
retval.depth = 0u;
76+
NBL_IF_CONSTEXPR(Mode == CumulativeProbabilityMode::EYTZINGER)
77+
{
78+
uint32_t P = 1u;
79+
uint32_t d = 0u;
80+
while (P < _size) { P <<= 1u; ++d; }
81+
retval.depth = d;
82+
}
5283
return retval;
5384
}
5485

5586
// BasicSampler interface
5687
codomain_type generate(const domain_type u) NBL_CONST_MEMBER_FUNC
5788
{
58-
// upper_bound returns first index where cumProb > u
59-
return hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u);
89+
NBL_IF_CONSTEXPR(Mode == CumulativeProbabilityMode::EYTZINGER)
90+
{
91+
const uint32_t leafBase = 1u << depth;
92+
uint32_t index = 1u;
93+
for (uint32_t iter = 0u; iter < depth; ++iter)
94+
{
95+
density_type key;
96+
cumProbAccessor.template get<density_type, uint32_t>(index, key);
97+
index = (index << 1u) | uint32_t(!(u < key));
98+
}
99+
const codomain_type result = codomain_type(index - leafBase);
100+
return result < codomain_type(storedCount) ? result : codomain_type(storedCount);
101+
}
102+
else
103+
{
104+
return hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u);
105+
}
60106
}
61107

62108
// TractableSampler interface
63109
codomain_type generate(const domain_type u, NBL_REF_ARG(cache_type) cache) NBL_CONST_MEMBER_FUNC
64110
{
65-
// #define NBL_CUMPROB_YOLO_READS
66-
#ifdef NBL_CUMPROB_YOLO_READS
67-
// YOLO approach: re-read the array after binary search.
68-
// The accessed elements are adjacent to the found index so the cache is warm.
69-
const codomain_type result = hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u);
70-
cache.oneBefore = density_type(0.0);
71-
if (result)
72-
cumProbAccessor.template get<density_type, codomain_type>(result - 1u, cache.oneBefore);
73-
cache.upperBound = density_type(1.0);
74-
if (result < storedCount)
75-
cumProbAccessor.template get<density_type, codomain_type>(result, cache.upperBound);
76-
#else
77-
// Tracking reads approach: stateful comparator captures CDF values during binary search.
78-
struct CdfComparator
111+
codomain_type result;
112+
NBL_IF_CONSTEXPR(Mode == CumulativeProbabilityMode::EYTZINGER)
79113
{
80-
bool operator()(const density_type value, const density_type rhs)
114+
// Descent visits one interior node per level. Going left tightens
115+
// the upper bound to the current key; going right tightens the
116+
// lower bound. The bucket is (final index - leafBase).
117+
cache.oneBefore = density_type(0.0);
118+
cache.upperBound = density_type(1.0);
119+
const uint32_t leafBase = 1u << depth;
120+
uint32_t index = 1u;
121+
for (uint32_t iter = 0u; iter < depth; ++iter)
81122
{
82-
const bool retval = value < rhs;
83-
if (retval)
84-
upperBound = rhs;
123+
density_type key;
124+
cumProbAccessor.template get<density_type, uint32_t>(index, key);
125+
const bool goRight = !(u < key);
126+
if (goRight)
127+
{
128+
cache.oneBefore = key;
129+
index = (index << 1u) | 1u;
130+
}
85131
else
86-
oneBefore = rhs;
87-
return retval;
132+
{
133+
cache.upperBound = key;
134+
index = (index << 1u);
135+
}
88136
}
89-
90-
density_type oneBefore;
91-
density_type upperBound;
92-
} comp;
93-
comp.oneBefore = density_type(0.0);
94-
comp.upperBound = density_type(1.0);
95-
const codomain_type result = hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u, comp);
96-
cache.oneBefore = comp.oneBefore;
97-
cache.upperBound = comp.upperBound;
98-
#endif
137+
const codomain_type raw = codomain_type(index - leafBase);
138+
result = raw < codomain_type(storedCount) ? raw : codomain_type(storedCount);
139+
}
140+
else NBL_IF_CONSTEXPR(Mode == CumulativeProbabilityMode::YOLO)
141+
{
142+
// Re-read the two adjacent CDF entries after the binary search.
143+
// Both sit on the cache lines the search just touched, so they are warm.
144+
result = hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u);
145+
cache.oneBefore = density_type(0.0);
146+
if (result)
147+
cumProbAccessor.template get<density_type, codomain_type>(result - 1u, cache.oneBefore);
148+
cache.upperBound = density_type(1.0);
149+
if (result < storedCount)
150+
cumProbAccessor.template get<density_type, codomain_type>(result, cache.upperBound);
151+
}
152+
else
153+
{
154+
// TRACKING: stateful comparator captures the CDF values straddling the
155+
// found index during the binary search itself, avoiding the two extra reads.
156+
struct CdfComparator
157+
{
158+
bool operator()(const density_type value, const density_type rhs)
159+
{
160+
const bool retval = value < rhs;
161+
if (retval)
162+
upperBound = rhs;
163+
else
164+
oneBefore = rhs;
165+
return retval;
166+
}
167+
168+
density_type oneBefore;
169+
density_type upperBound;
170+
} comp;
171+
comp.oneBefore = density_type(0.0);
172+
comp.upperBound = density_type(1.0);
173+
result = hlsl::upper_bound(cumProbAccessor, 0u, storedCount, u, comp);
174+
cache.oneBefore = comp.oneBefore;
175+
cache.upperBound = comp.upperBound;
176+
}
99177
return result;
100178
}
101179

@@ -111,16 +189,34 @@ struct CumulativeProbabilitySampler
111189

112190
density_type backwardPdf(const codomain_type v) NBL_CONST_MEMBER_FUNC
113191
{
114-
density_type retval = density_type(1.0);
115-
if (v < storedCount)
116-
cumProbAccessor.template get<density_type, codomain_type>(v, retval);
117-
if (v)
192+
NBL_IF_CONSTEXPR(Mode == CumulativeProbabilityMode::EYTZINGER)
118193
{
119-
density_type prev;
120-
cumProbAccessor.template get<density_type, codomain_type>(v - 1u, prev);
121-
retval -= prev;
194+
// Leaves store the CDF directly; the last real leaf is normalized
195+
// to 1.0 and padded leaves (if any) also hold 1.0.
196+
const uint32_t leafBase = 1u << depth;
197+
density_type retval;
198+
cumProbAccessor.template get<density_type, uint32_t>(leafBase + uint32_t(v), retval);
199+
if (v)
200+
{
201+
density_type prev;
202+
cumProbAccessor.template get<density_type, uint32_t>(leafBase + uint32_t(v) - 1u, prev);
203+
retval -= prev;
204+
}
205+
return retval;
206+
}
207+
else
208+
{
209+
density_type retval = density_type(1.0);
210+
if (v < storedCount)
211+
cumProbAccessor.template get<density_type, codomain_type>(v, retval);
212+
if (v)
213+
{
214+
density_type prev;
215+
cumProbAccessor.template get<density_type, codomain_type>(v - 1u, prev);
216+
retval -= prev;
217+
}
218+
return retval;
122219
}
123-
return retval;
124220
}
125221

126222
weight_type backwardWeight(const codomain_type v) NBL_CONST_MEMBER_FUNC
@@ -129,7 +225,8 @@ struct CumulativeProbabilitySampler
129225
}
130226

131227
CumProbAccessor cumProbAccessor;
132-
uint32_t storedCount;
228+
uint32_t storedCount; // N - 1 (last real bucket index)
229+
uint32_t depth; // EYTZINGER only: ceil(log2(N)), iteration count; leafBase = 1 << depth
133230
};
134231

135232
} // namespace sampling

0 commit comments

Comments
 (0)