@@ -97,6 +97,68 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
9797template <typename scalar_t, uint32_t WorkgroupSize>
9898NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) * WorkgroupSize;
9999
100+
101+ template<uint32_t N, uint32_t H>
102+ enable_if_t<H <= N, uint32_t> bitShiftRightHigher (uint32_t i)
103+ {
104+ // Highest H bits are numbered N-1 through N - H
105+ // N - H is then the middle bit
106+ // Lowest bits numbered from 0 through N - H - 1
107+ uint32_t low = i & ((1 << (N - H)) - 1 );
108+ uint32_t mid = i & (1 << (N - H));
109+ uint32_t high = i & ~((1 << (N - H + 1 )) - 1 );
110+
111+ high >>= 1 ;
112+ mid <<= H - 1 ;
113+
114+ return mid | high | low;
115+ }
116+
117+ template<uint32_t N, uint32_t H>
118+ enable_if_t<H <= N, uint32_t> bitShiftLeftHigher (uint32_t i)
119+ {
120+ // Highest H bits are numbered N-1 through N - H
121+ // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
122+ // Lowest bits numbered from 0 through N - H - 1
123+ uint32_t low = i & ((1 << (N - H)) - 1 );
124+ uint32_t mid = i & (~((1 << (N - H)) - 1 ) | ~(1 << (N - 1 )));
125+ uint32_t high = i & (1 << (N - 1 ));
126+
127+ mid <<= 1 ;
128+ high >>= H - 1 ;
129+
130+ return mid | high | low;
131+ }
132+
133+ // For an N-bit number, mirrors it around the Nyquist frequency, which for the range [0, 2^N - 1] is precisely 2^(N - 1)
134+ template<uint32_t N>
135+ uint32_t mirror (uint32_t i)
136+ {
137+ return ((1 << N) - i) & ((1 << N) - 1 )
138+ }
139+
140+ // This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
141+ // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
142+ template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
143+ uint32_t getFrequencyAt (uint32_t idx)
144+ {
145+ NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
146+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
147+
148+ return mirror <FFT_SIZE_LOG_2>(bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(glsl::bitfieldReverse<uint32_t>(idx) >> (32 - FFT_SIZE_LOG_2)));
149+ }
150+
151+ // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
152+ // It is essentially the inverse of `getFrequencyAt`
153+ template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
154+ uint32_t getOutputAt (uint32_t freqIdx)
155+ {
156+ NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
157+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
158+
159+ return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(mirror <FFT_SIZE_LOG_2>(freqIdx))) >> (32 - FFT_SIZE_LOG_2);
160+ }
161+
100162} //namespace fft
101163
102164// ----------------------------------- End Utils -----------------------------------------------
0 commit comments