22#define _NBL_BUILTIN_HLSL_WORKGROUP_SHUFFLE_INCLUDED_
33
44#include "nbl/builtin/hlsl/memory_accessor.hlsl"
5+ #include "nbl/builtin/hlsl/functional.hlsl"
56
67// TODO: Add other shuffles
78
@@ -14,26 +15,87 @@ namespace hlsl
1415namespace workgroup
1516{
1617
18+ // ------------------------------------- Skeletons for implementing other Shuffles --------------------------------
19+
// Generic exchange of one `T` per invocation through `sharedmemAdaptor`:
// every invocation stores its `value` at one index, all invocations
// synchronise, then each one loads from a (possibly different) index.
//
// NOTE(review): no barrier is issued after the load, so back-to-back calls
// that reuse the same indices presumably need a barrier supplied by the caller.
template<typename SharedMemoryAdaptor, typename T>
struct Shuffle
{
    // value            - in: element to publish; out: receives what `get` loads from `loadIdx`
    // storeIdx         - index handed to `set` for this invocation's element
    // loadIdx          - index handed to `get` after the barrier
    // sharedmemAdaptor - accessor providing `set`, `get` and `workgroupExecutionAndMemoryBarrier`
    static void __call(NBL_REF_ARG(T) value, uint32_t storeIdx, uint32_t loadIdx, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        // TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
        sharedmemAdaptor.template set<T>(storeIdx, value);

        // Every store has to have landed before anybody starts loading
        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();

        sharedmemAdaptor.template get<T>(loadIdx, value);
    }

    // Convenience overload: the store index defaults to `SubgroupContiguousIndex()`
    // (the thread's ID in the workgroup), only the load index is supplied.
    static void __call(NBL_REF_ARG(T) value, uint32_t loadIdx, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        const uint32_t selfIdx = uint32_t(SubgroupContiguousIndex());
        __call(value, selfIdx, loadIdx, sharedmemAdaptor);
    }
};
40+
// Shuffle whose load index is derived from the store index by a unary functor:
// the element is stored at `a` and, after a barrier, loaded from `UnOp()(a)`.
//
// NOTE(review): no barrier is issued after the load, so back-to-back calls
// that reuse the same indices presumably need a barrier supplied by the caller.
template<class UnOp, typename SharedMemoryAdaptor, typename T>
struct ShuffleUnOp
{
    // value            - in: element to publish; out: receives what `get` loads from `UnOp()(a)`
    // a                - store index, also the argument fed to the unary functor
    // sharedmemAdaptor - accessor providing `set`, `get` and `workgroupExecutionAndMemoryBarrier`
    static void __call(NBL_REF_ARG(T) value, uint32_t a, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        UnOp op;
        // TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
        sharedmemAdaptor.template set<T>(a, value);

        // Every store has to have landed before anybody starts loading
        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();

        sharedmemAdaptor.template get<T>(op(a), value);
    }

    // Convenience overload: `a` defaults to `SubgroupContiguousIndex()`, so the
    // element is stored at the thread's own index and loaded from the functor
    // applied to that index.
    static void __call(NBL_REF_ARG(T) value, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        const uint32_t selfIdx = uint32_t(SubgroupContiguousIndex());
        __call(value, selfIdx, sharedmemAdaptor);
    }
};
62+
// Shuffle whose load index is produced by a binary functor: the element is
// stored at `a` and, after a barrier, loaded from `BinOp()(a, b)`.
//
// NOTE(review): no barrier is issued after the load, so back-to-back calls
// that reuse the same indices presumably need a barrier supplied by the caller.
template<class BinOp, typename SharedMemoryAdaptor, typename T>
struct ShuffleBinOp
{
    // value            - in: element to publish; out: receives what `get` loads from `BinOp()(a, b)`
    // a                - store index, also the first operand of the binary functor
    // b                - second operand of the binary functor
    // sharedmemAdaptor - accessor providing `set`, `get` and `workgroupExecutionAndMemoryBarrier`
    static void __call(NBL_REF_ARG(T) value, uint32_t a, uint32_t b, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        BinOp op;
        // TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
        sharedmemAdaptor.template set<T>(a, value);

        // Every store has to have landed before anybody starts loading
        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();

        sharedmemAdaptor.template get<T>(op(a, b), value);
    }

    // Convenience overload: the first operand defaults to `SubgroupContiguousIndex()`
    // (the thread's ID in the workgroup); only the second operand `b` is supplied.
    static void __call(NBL_REF_ARG(T) value, uint32_t b, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        const uint32_t selfIdx = uint32_t(SubgroupContiguousIndex());
        __call(value, selfIdx, b, sharedmemAdaptor);
    }
};
3684
85+ // ------------------------------------------ ShuffleXor ---------------------------------------------------------------
86+
// XOR shuffle with an explicit thread index: forwards to `ShuffleBinOp`
// instantiated with `bit_xor<uint32_t>`, i.e. `value` is stored at `threadID`
// and loaded back from the index `bit_xor` yields for (`threadID`, `mask`).
template<typename SharedMemoryAdaptor, typename T>
void shuffleXor(NBL_REF_ARG(T) value, uint32_t threadID, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
{
    typedef ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T> xor_shuffle_t;
    xor_shuffle_t::__call(value, threadID, mask, sharedmemAdaptor);
}
92+
// XOR shuffle with an implicit thread index: forwards to the `ShuffleBinOp`
// overload that takes only the second operand, which substitutes
// `SubgroupContiguousIndex()` for the first operand of `bit_xor<uint32_t>`.
template<typename SharedMemoryAdaptor, typename T>
void shuffleXor(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
{
    typedef ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T> xor_shuffle_t;
    xor_shuffle_t::__call(value, mask, sharedmemAdaptor);
}
98+
3799}
38100}
39101}
0 commit comments