Skip to content

Commit 7ab5e68

Browse files
authored
Merge pull request #963 from CrabExtra/master
Move branch over
2 parents 98d8151 + a5e811b commit 7ab5e68

7 files changed

Lines changed: 630 additions & 58 deletions

File tree

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+
#include "nbl/builtin/hlsl/functional.hlsl"
6+
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
7+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
namespace bitonic_sort
14+
{
15+
16+
// Compile-time configuration bundle for a bitonic sort.
//
// Template parameters:
//   KeyType            - type of the keys being ordered, re-exported as `key_t`
//   ValueType          - type of the payload carried with each key, re-exported as `value_t`
//   _SubgroupSizeLog2  - base-2 logarithm of the subgroup size, re-exported as `SubgroupSizeLog2`
//   Comparator         - comparison functor type, re-exported as `comparator_t`
//
// The leading underscore keeps the template parameter clearly distinct from the
// public constant it initializes (the two used to differ only by letter case).
template<typename KeyType, typename ValueType, uint32_t _SubgroupSizeLog2, typename Comparator>
struct bitonic_sort_config
{
    using key_t = KeyType;
    using value_t = ValueType;
    using comparator_t = Comparator;

    // log2 of the subgroup size, exactly as supplied by the user
    static const uint32_t SubgroupSizeLog2 = _SubgroupSizeLog2;
    // subgroup size derived from the logarithm above, i.e. 2^SubgroupSizeLog2
    static const uint32_t SubgroupSize = 1u << SubgroupSizeLog2;
};
25+
26+
template<typename Config, class device_capabilities = void>
27+
struct bitonic_sort;
28+
29+
30+
// Sorts N = 2^Log2N elements that live entirely in one invocation.
// The primary template only declares the call operator; a body is supplied
// per supported size through partial specialization.
//
// NOTE(review): in HLSL an array parameter without `inout` is copy-in, so the
// reordering done on `data` may not be visible to the caller when this is
// compiled as a shader (in C++ the parameter decays to a pointer) - confirm.
template<typename sortable_t, uint32_t Log2N, typename Comparator>
struct LocalPasses
{
    static const uint32_t N = 1u << Log2N;
    void operator()(bool ascending, sortable_t data[N], NBL_CONST_REF_ARG(Comparator) comp);
};

// Two-element case (Log2N == 1): a single conditional exchange.
template<typename sortable_t, typename Comparator>
struct LocalPasses<sortable_t, 1, Comparator>
{
    static const uint32_t N = 2;

    // ascending - true puts the element that `comp` orders first into data[0],
    //             false puts it into data[1]
    // data      - the two elements to order
    // comp      - comparison functor, invoked as comp(a, b)
    void operator()(bool ascending, sortable_t data[N], NBL_CONST_REF_ARG(Comparator) comp)
    {
        // ascending:  exchange when comp(data[1], data[0]) holds
        // descending: exchange when comp(data[0], data[1]) holds
        bool outOfOrder;
        if (ascending)
            outOfOrder = comp(data[1], data[0]);
        else
            outOfOrder = comp(data[0], data[1]);

        // already in the requested order, nothing to do
        if (!outOfOrder)
            return;

        const sortable_t formerFirst = data[0];
        data[0] = data[1];
        data[1] = formerFirst;
    }
};
57+
58+
59+
} // namespace bitonic_sort
60+
} // namespace hlsl
61+
} // namespace nbl
62+
63+
#endif
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_BITONIC_SORT_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl"
5+
6+
namespace nbl
7+
{
8+
namespace hlsl
9+
{
10+
namespace workgroup
11+
{
12+
namespace bitonic_sort
13+
{
14+
// The SharedMemoryAccessor MUST provide the following methods (signatures written for the default `V = uint32_t`, `I = uint32_t`):
15+
// * void get(uint32_t index, NBL_REF_ARG(uint32_t) value);
16+
// * void set(uint32_t index, in uint32_t value);
17+
// * void workgroupExecutionAndMemoryBarrier();
18+
template<typename T, typename V = uint32_t, typename I = uint32_t>
19+
NBL_BOOL_CONCEPT BitonicSortSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T, V, I>;
20+
21+
// The Accessor MUST provide the following methods (signatures written for the default `I = uint32_t`):
22+
// * void get(uint32_t index, NBL_REF_ARG(pair<KeyType, ValueType>) value);
23+
// * void set(uint32_t index, in pair<KeyType, ValueType> value);
24+
template<typename T, typename KeyType, typename ValueType, typename I = uint32_t>
25+
NBL_BOOL_CONCEPT BitonicSortAccessor = concepts::accessors::GenericDataAccessor<T, pair<KeyType, ValueType>, I>;
26+
27+
}
28+
}
29+
}
30+
}
31+
#endif

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,6 @@ namespace nbl
2222
{
2323
namespace hlsl
2424
{
25-
26-
// TODO: flesh out and move to `nbl/builtin/hlsl/utility.hlsl`
27-
template<typename T1, typename T2>
28-
struct pair
29-
{
30-
using first_type = T1;
31-
using second_type = T2;
32-
33-
first_type first;
34-
second_type second;
35-
};
36-
3725
namespace accessor_adaptors
3826
{
3927
namespace impl
@@ -227,4 +215,4 @@ struct Offset : impl::OffsetBase<IndexType,_Offset>
227215
}
228216
}
229217
}
230-
#endif
218+
#endif
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
5+
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
6+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
7+
8+
namespace nbl
9+
{
10+
namespace hlsl
11+
{
12+
namespace subgroup
13+
{
14+
namespace bitonic_sort
15+
{
16+
using namespace nbl::hlsl::bitonic_sort;
17+
18+
// Subgroup-wide bitonic sort in which every invocation owns two elements, `lo` and `hi`.
//
// After `__call` the keys are ordered lane-major: lane 0 `lo`, lane 0 `hi`,
// lane 1 `lo`, lane 1 `hi`, and so on. Each element's `workgroupRelativeIndex`
// travels together with its `key`.
//
// Template parameters:
//   KeyType             - key type, re-exported as `key_t`
//   Comparator          - default-constructible functor, invoked as comp(a, b) on keys
//   device_capabilities - not used by the code in this struct
//
// NOTE(review): `WorkgroupType` is not declared in this header nor in the visible part of
// `bitonic_sort/common.hlsl`; presumably an include provides it - confirm. Only its `key`
// and `workgroupRelativeIndex` members are touched here.
template<typename KeyType, typename Comparator, class device_capabilities = void>
struct bitonic_sort_wgtype
{
    using WGType = WorkgroupType<KeyType>;
    using key_t = KeyType;
    using comparator_t = Comparator;

    // Compare-exchange of the two elements held by this invocation.
    // They are exchanged when comp(hi.key, lo.key) equals `bitonicAscending`.
    static void __compareSwapLocal(bool bitonicAscending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
    {
        comparator_t comp;

        const bool swap = comp(hi.key, lo.key) == bitonicAscending;
        WGType tmp = lo;
        lo.key = swap ? hi.key : lo.key;
        lo.workgroupRelativeIndex = swap ? hi.workgroupRelativeIndex : lo.workgroupRelativeIndex;
        hi.key = swap ? tmp.key : hi.key;
        hi.workgroupRelativeIndex = swap ? tmp.workgroupRelativeIndex : hi.workgroupRelativeIndex;
    }

    // Compare-exchange of `elem` with the same-slot element of lane `invocationID ^ partner`.
    // The lane whose `partner` bit is clear ends up with the element that goes first,
    // the lane whose bit is set ends up with the other one.
    // `partner` must be a non-zero lane mask; `mergeStage` only passes powers of two.
    // NOTE(review): assumes the partner lane is active when the shuffle executes - confirm
    // against the contract of `glsl::subgroupShuffleXor`.
    static void __compareExchangeWithPartner(bool bitonicAscending, uint32_t invocationID, uint32_t partner, NBL_REF_ARG(WGType) elem)
    {
        comparator_t comp;

        const bool isUpper = (invocationID & partner) != 0u;

        // members are shuffled individually
        const key_t partnerKey = glsl::subgroupShuffleXor(elem.key, partner);
        const uint32_t partnerIdx = glsl::subgroupShuffleXor(elem.workgroupRelativeIndex, partner);

        // Both lanes of the pair evaluate comp(upper lane's key, lower lane's key), i.e. the
        // same operands as the local comp(hi.key, lo.key). They therefore always agree on
        // whether to exchange, and no element is duplicated or lost.
        bool upperBeforeLower;
        if (isUpper)
            upperBeforeLower = comp(elem.key, partnerKey);
        else
            upperBeforeLower = comp(partnerKey, elem.key);
        const bool takePartner = upperBeforeLower == bitonicAscending;

        elem.key = takePartner ? partnerKey : elem.key;
        elem.workgroupRelativeIndex = takePartner ? partnerIdx : elem.workgroupRelativeIndex;
    }

    // Runs the `stage + 1` compare-exchange passes of one bitonic merge stage.
    //
    // stage            - stage number; pass `p` uses element stride 2^(stage - p)
    // bitonicAscending - direction in which this invocation's run is merged
    // invocationID     - this invocation's index within the subgroup
    // lo, hi           - the two elements owned by this invocation, updated in place
    static void mergeStage(
        uint32_t stage,
        bool bitonicAscending,
        uint32_t invocationID,
        NBL_REF_ARG(WGType) lo,
        NBL_REF_ARG(WGType) hi)
    {
        [unroll]
        for (uint32_t pass = 0u; pass <= stage; ++pass)
        {
            // stride counts elements; with two elements per lane the lane distance is half of it
            const uint32_t stride = 1u << (stage - pass);
            const uint32_t partner = stride >> 1;

            if (partner == 0u)
            {
                // element stride 1: both operands live in this invocation
                __compareSwapLocal(bitonicAscending, lo, hi);
            }
            else
            {
                // element stride >= 2: slot `lo` pairs with the partner's `lo` and slot `hi`
                // with the partner's `hi`. Each slot is exchanged on its own so that every
                // lane keeps exactly the two elements belonging to its own positions.
                __compareExchangeWithPartner(bitonicAscending, invocationID, partner, lo);
                __compareExchangeWithPartner(bitonicAscending, invocationID, partner, hi);
            }
        }
    }

    // Sorts the 2 * subgroupSize elements held across the subgroup.
    //
    // ascending - direction passed to the last merge stage
    // lo, hi    - the two elements owned by this invocation, updated in place
    static void __call(bool ascending, NBL_REF_ARG(WGType) lo, NBL_REF_ARG(WGType) hi)
    {
        const uint32_t id = glsl::gl_SubgroupInvocationID();
        const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();

        [unroll]
        for (uint32_t s = 0u; s <= subgroupSizeLog2; ++s)
        {
            // Intermediate stages alternate direction on bit `s` of the lane index so that
            // neighbouring runs form a bitonic sequence for the next stage; the last stage
            // merges in the caller-requested direction.
            const bool dir = (s == subgroupSizeLog2) ? ascending : ((id & (1u << s)) != 0u);
            mergeStage(s, dir, id, lo, hi);
        }
    }
};
88+
89+
} // namespace bitonic_sort
90+
} // namespace subgroup
91+
} // namespace hlsl
92+
} // namespace nbl
93+
94+
#endif
Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,70 @@
1-
// Copyright (C) 2024 - DevSH Graphics Programming Sp. z O.O.
2-
// This file is part of the "Nabla Engine".
3-
// For conditions of distribution and use, see copyright notice in nabla.h
4-
#ifndef _NBL_BUILTIN_HLSL_UTILITY_INCLUDED_
5-
#define _NBL_BUILTIN_HLSL_UTILITY_INCLUDED_
6-
7-
8-
#include <nbl/builtin/hlsl/type_traits.hlsl>
9-
10-
11-
// for now we only implement declval
12-
namespace nbl
13-
{
14-
namespace hlsl
15-
{
16-
template<typename T>
17-
const static bool always_true = true;
18-
#ifndef __HLSL_VERSION
19-
20-
template<class T>
21-
std::add_rvalue_reference_t<T> declval() noexcept
22-
{
23-
static_assert(false,"Actually calling declval is ill-formed.");
24-
}
25-
26-
#else
27-
28-
namespace experimental
29-
{
30-
31-
template<class T>
32-
T declval() {}
33-
34-
}
35-
36-
#endif
37-
}
38-
}
39-
40-
#endif
1+
// Copyright (C) 2024 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_UTILITY_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_UTILITY_INCLUDED_
6+
7+
8+
#include <nbl/builtin/hlsl/type_traits.hlsl>
9+
10+
11+
namespace nbl
12+
{
13+
namespace hlsl
14+
{
15+
16+
// Plain two-member aggregate holding one value of each type.
template<typename T1, typename T2>
struct pair
{
    using first_type = T1;
    using second_type = T2;

    first_type first;
    second_type second;
};

// Builds a pair whose `first` is `f` and whose `second` is `s`.
template<typename T1, typename T2>
pair<T1, T2> make_pair(T1 f, T2 s)
{
    pair<T1, T2> result;
    result.first = f;
    result.second = s;
    return result;
}

// Exchanges the contents of `a` and `b`, both members at once.
template<typename T1, typename T2>
void swap(NBL_REF_ARG(pair<T1, T2>) a, NBL_REF_ARG(pair<T1, T2>) b)
{
    // keep a copy of `a` so it can be handed to `b` after `a` is overwritten
    const pair<T1, T2> formerA = a;
    a = b;
    b = formerA;
}
45+
46+
// Constant that is `true` for every `T`. Its value nominally depends on `T`, so an
// expression built from it is only evaluated once the enclosing template is instantiated.
template<typename T>
const static bool always_true = true;
#ifndef __HLSL_VERSION

// C++ build: yields `T&&` (after reference collapsing) for use in unevaluated
// contexts such as `decltype(...)`. It must never actually be called.
template<class T>
std::add_rvalue_reference_t<T> declval() noexcept
{
    // The condition is made dependent on `T` on purpose: a literal `false` is rejected
    // at template definition time by compilers that predate P2593, even when `declval`
    // is never instantiated. This form only fires if the body really gets instantiated.
    static_assert(!always_true<T>,"Actually calling declval is ill-formed.");
}

#else

namespace experimental
{

// HLSL build: declared to return `T` by value; the body is intentionally empty.
template<class T>
T declval() {}

}

#endif
67+
}
68+
}
69+
70+
#endif

0 commit comments

Comments
 (0)