Skip to content

Commit b9ec10e

Browse files
committed
fix transpose when matrix N!=M
1 parent 8e7615d commit b9ec10e

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findIL
159159

160160
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitReverse_helper, bitReverse, (T), (T), T)
161161
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(dot_helper, dot, (T), (T)(T), typename vector_traits<T>::scalar_type)
162-
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(transpose_helper, transpose, (T), (T), T)
163162
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(length_helper, length, (T), (T), typename vector_traits<T>::scalar_type)
164163
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(normalize_helper, normalize, (T), (T), T)
165164
template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(rsqrt_helper, inverseSqrt, (T), (T), T)
@@ -204,6 +203,17 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(bitCount_helper, bitCou
204203
#undef ARG
205204
#undef AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER
206205

206+
template<typename Matrix> NBL_PARTIAL_REQ_TOP(concepts::Matrix<Matrix>)
207+
struct transpose_helper<Matrix NBL_PARTIAL_REQ_BOT(concepts::Matrix<Matrix>) >
208+
{
209+
using transposed_t = typename matrix_traits<Matrix>::transposed_type;
210+
211+
static transposed_t __call(NBL_CONST_REF_ARG(Matrix) m)
212+
{
213+
using traits = matrix_traits<Matrix>;
214+
return spirv::transpose<Matrix>(m);
215+
}
216+
};
207217
template<typename UInt64> NBL_PARTIAL_REQ_TOP(is_same_v<UInt64, uint64_t>)
208218
struct find_msb_helper<UInt64 NBL_PARTIAL_REQ_BOT(is_same_v<UInt64, uint64_t>) >
209219
{

include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
#include "spirv/unified1/spirv.hpp"
1111

1212
#include <nbl/builtin/hlsl/vector_utils/vector_traits.hlsl>
13+
#include <nbl/builtin/hlsl/matrix_utils/matrix_traits.hlsl>
1314
#include <nbl/builtin/hlsl/type_traits.hlsl>
1415
#include <nbl/builtin/hlsl/concepts.hlsl>
1516
#include <nbl/builtin/hlsl/concepts/vector.hlsl>
17+
#include <nbl/builtin/hlsl/concepts/matrix.hlsl>
1618

1719
namespace nbl
1820
{
@@ -331,9 +333,9 @@ template<typename Vector NBL_FUNC_REQUIRES(is_vector_v<Vector>)
331333
[[vk::ext_instruction( spv::OpDot )]]
332334
typename vector_traits<Vector>::scalar_type dot(Vector lhs, Vector rhs);
333335

334-
template<typename Matrix>
336+
template<typename Matrix NBL_FUNC_REQUIRES(is_matrix_v<Matrix>)
335337
[[vk::ext_instruction( spv::OpTranspose )]]
336-
Matrix transpose(Matrix mat);
338+
typename matrix_traits<Matrix>::transposed_type transpose(Matrix mat);
337339

338340
template<typename Integral>
339341
[[vk::ext_instruction(spv::OpBitCount)]]

0 commit comments

Comments
 (0)