@@ -24,18 +24,14 @@ namespace fft
2424
2525// ---------------------------------- Utils -----------------------------------------------
2626template<typename SharedMemoryAdaptor, typename Scalar>
27- struct exchangeValues;
28-
29- template<typename SharedMemoryAdaptor>
30- struct exchangeValues<SharedMemoryAdaptor, float16_t>
27+ struct exchangeValues
3128{
32- static void __call (NBL_REF_ARG (complex_t<float16_t >) lo, NBL_REF_ARG (complex_t<float16_t >) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
29+ static void __call (NBL_REF_ARG (complex_t<Scalar >) lo, NBL_REF_ARG (complex_t<Scalar >) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
3330 {
3431 const bool topHalf = bool (threadID & stride);
35- // Pack two halves into a single uint32_t
36- uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real (), lo.imag ()) : float16_t2 (hi.real (), hi.imag ()));
37- shuffleXor<SharedMemoryAdaptor, uint32_t>(toExchange, stride, sharedmemAdaptor);
38- float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
32+ // Pack into float vector because ternary operator does not support structs
33+ vector <Scalar, 2 > exchanged = topHalf ? vector <Scalar, 2 >(lo.real (), lo.imag ()) : vector <Scalar, 2 >(hi.real (), hi.imag ());
34+ shuffleXor<SharedMemoryAdaptor, vector <Scalar, 2 > >(exchanged, stride, sharedmemAdaptor);
3935 if (topHalf)
4036 {
4137 lo.real (exchanged.x);
@@ -45,51 +41,7 @@ struct exchangeValues<SharedMemoryAdaptor, float16_t>
4541 {
4642 hi.real (exchanged.x);
4743 hi.imag (exchanged.y); // bottom half receives into `hi`; both components must land in the same value (was `lo.imag`, which clobbered `lo` and left `hi.imag` stale)
48- }
49- }
50- };
51-
52- template<typename SharedMemoryAdaptor>
53- struct exchangeValues<SharedMemoryAdaptor, float32_t>
54- {
55- static void __call (NBL_REF_ARG (complex_t<float32_t>) lo, NBL_REF_ARG (complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
56- {
57- const bool topHalf = bool (threadID & stride);
58- // pack into `float32_t2` because ternary operator doesn't support structs
59- float32_t2 exchanged = topHalf ? float32_t2 (lo.real (), lo.imag ()) : float32_t2 (hi.real (), hi.imag ());
60- shuffleXor<SharedMemoryAdaptor, float32_t2>(exchanged, stride, sharedmemAdaptor);
61- if (topHalf)
62- {
63- lo.real (exchanged.x);
64- lo.imag (exchanged.y);
6544 }
66- else
67- {
68- hi.real (exchanged.x);
69- hi.imag (exchanged.y);
70- }
71- }
72- };
73-
74- template<typename SharedMemoryAdaptor>
75- struct exchangeValues<SharedMemoryAdaptor, float64_t>
76- {
77- static void __call (NBL_REF_ARG (complex_t<float64_t>) lo, NBL_REF_ARG (complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
78- {
79- const bool topHalf = bool (threadID & stride);
80- // pack into `float64_t2` because ternary operator doesn't support structs
81- float64_t2 exchanged = topHalf ? float64_t2 (lo.real (), lo.imag ()) : float64_t2 (hi.real (), hi.imag ());
82- shuffleXor<SharedMemoryAdaptor, float64_t2 >(exchanged, stride, sharedmemAdaptor);
83- if (topHalf)
84- {
85- lo.real (exchanged.x);
86- lo.imag (exchanged.y);
87- }
88- else
89- {
90- hi.real (exchanged.x);
91- hi.imag (exchanged.y);
92- }
9345 }
9446};
9547
@@ -170,7 +122,7 @@ uint32_t getNegativeIndex(uint32_t idx)
170122
171123// Utility to unpack two values from the packed FFT of X + iY. Results are written back into the input arguments: x is stored in lo and y in hi
172124template<typename Scalar>
173- void unpack (NBL_CONST_REF_ARG (complex_t<Scalar>) lo, NBL_CONST_REF_ARG (complex_t<Scalar>) hi)
125+ void unpack (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi)
174126{
175127 complex_t<Scalar> x = (lo + conj (hi)) * Scalar (0.5 );
176128 hi = rotateRight<Scalar>(lo - conj (hi)) * 0.5 ;
0 commit comments