Skip to content
Draft
843 changes: 843 additions & 0 deletions crypto/math/src/fft/bowers_fft_batch.rs

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions crypto/math/src/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
pub mod bit_reversing;
#[cfg(feature = "alloc")]
pub mod bowers_fft;
#[cfg(feature = "alloc")]
pub mod bowers_fft_batch;
pub mod errors;
#[cfg(feature = "alloc")]
pub mod roots_of_unity;
#[cfg(feature = "alloc")]
pub mod two_half_fft;

#[cfg(all(test, feature = "alloc"))]
pub(crate) mod test_helpers;
246 changes: 246 additions & 0 deletions crypto/math/src/fft/two_half_fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
//! Cache-blocked, transpose-free batched FFT (port of Plonky3's two-half
//! `Radix2DitParallel::dft_batch`).
//!
//! The flat Bowers DIF streams the whole `n·m` buffer with large strides at the
//! early layers, thrashing cache for large `n`. This kernel keeps every layer
//! cache-resident by interleaving bit-reversals: bit-reverse → first `mid` DIT
//! layers within `2^mid`-row chunks → bit-reverse → remaining layers within
//! `2^(log_n−mid)`-row chunks → bit-reverse. The bit-reversals turn the
//! large-stride butterflies into chunk-local ones — the cache win the flat
//! Bowers misses. Output is natural order, identical to
//! `bowers_fft_batch_row_major` followed by `in_place_bit_reverse_permute_row_major`.
//!
//! Twiddles are precomputed once per size in [`TwoHalfTwiddles`] and reused
//! across calls (the trace LDE invokes this once per direction per domain, and
//! the same domain recurs across tables and rounds).

#[cfg(feature = "alloc")]
use crate::fft::bit_reversing::reverse_index;
#[cfg(feature = "alloc")]
use crate::fft::bowers_fft_batch::in_place_bit_reverse_permute_row_major;
#[cfg(feature = "alloc")]
use crate::fft::errors::FFTError;
#[cfg(feature = "alloc")]
use crate::field::{
element::FieldElement,
traits::{IsFFTField, IsField, IsSubFieldOf},
};
// `Vec` is not in the prelude for `no_std` + `alloc` builds; import it explicitly.
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(all(feature = "alloc", feature = "parallel"))]
use rayon::prelude::*;

/// Reorders `v` in place so that every element trades places with the element
/// stored at its bit-reversed index. The slice length is expected to be a
/// power of two.
#[cfg(feature = "alloc")]
fn bit_reverse_vec<F: IsField>(v: &mut [FieldElement<F>]) {
    let len = v.len();
    let size = len as u64;
    (0..len)
        .map(|idx| (idx, reverse_index(idx, size)))
        // Each pair must be exchanged exactly once, so only act when the
        // partner index lies strictly ahead of the current one.
        .filter(|&(idx, mirrored)| mirrored > idx)
        .for_each(|(idx, mirrored)| v.swap(idx, mirrored));
}

/// Precomputed twiddles for a size-`2^log_n` two-half FFT in one direction.
///
/// `tw` is the flat geometric array `[ω⁰, ω¹, …, ω^(n/2−1)]` (`ω` the forward
/// root for the forward transform, its inverse for the inverse transform);
/// `bitrev_tw` is its bit-reversal permutation, used by the second-half layers.
/// Build once and share across calls of the same size and direction.
///
/// All fields are private: instances are created through `TwoHalfTwiddles::new`
/// and consumed by `fft_batch_two_half`, which rejects a buffer whose row
/// count is not `2^log_n`.
#[cfg(feature = "alloc")]
pub struct TwoHalfTwiddles<F: IsField> {
/// Base-2 logarithm of the transform size these tables were built for.
log_n: usize,
/// Powers of the root in natural order (`n/2` entries); read by the
/// first-half layers.
tw: Vec<FieldElement<F>>,
/// Bit-reversed copy of `tw`; read by the second-half layers.
bitrev_tw: Vec<FieldElement<F>>,
}

#[cfg(feature = "alloc")]
impl<F: IsFFTField> TwoHalfTwiddles<F> {
    /// Precompute twiddles for a size-`2^log_n` transform.
    ///
    /// `inverse = true` selects the (unscaled) inverse transform (uses `ω⁻¹`);
    /// the `1/n` normalization is the caller's responsibility.
    ///
    /// # Errors
    ///
    /// * [`FFTError::DomainSizeError`] if `2^log_n` does not fit in `usize`.
    /// * [`FFTError::InputError`] if the field has no primitive root of unity
    ///   of order `2^log_n`, or if that root cannot be inverted.
    pub fn new(log_n: usize, inverse: bool) -> Result<Self, FFTError> {
        // Guard the shift below: an out-of-range shift panics in debug builds
        // and silently wraps in release builds, which would yield empty tables
        // tagged with a bogus `log_n`.
        if log_n >= usize::BITS as usize {
            return Err(FFTError::DomainSizeError(log_n));
        }
        let n = 1usize << log_n;
        let half = n / 2;

        // `omega` is unused when half == 0 (log_n == 0), so skip the lookup.
        let omega = if half == 0 {
            FieldElement::<F>::one()
        } else {
            let fwd = F::get_primitive_root_of_unity(log_n as u64)
                .map_err(|_| FFTError::InputError(n))?;
            if inverse {
                fwd.inv().map_err(|_| FFTError::InputError(n))?
            } else {
                fwd
            }
        };

        // Natural-order table: tw[i] = omega^i for i in 0..n/2.
        let mut tw: Vec<FieldElement<F>> = Vec::with_capacity(half);
        let mut cur = FieldElement::<F>::one();
        for _ in 0..half {
            tw.push(cur.clone());
            cur = &cur * &omega;
        }

        // Same values, bit-reversed, for the second-half layers.
        let mut bitrev_tw = tw.clone();
        bit_reverse_vec(&mut bitrev_tw);

        Ok(Self {
            log_n,
            tw,
            bitrev_tw,
        })
    }
}

/// Applies one DIT butterfly to a pair of rows, sharing a single twiddle:
/// for each column, `lo ← lo + tw·hi` and `hi ← lo − tw·hi` (both computed
/// from the old `lo`). The product `tw·hi` is the mixed F×E multiplication.
/// Columns are paired positionally; iteration stops at the shorter row.
#[cfg(feature = "alloc")]
#[inline]
fn dit_butterfly_rows<F, E>(
    lo: &mut [FieldElement<E>],
    hi: &mut [FieldElement<E>],
    tw: &FieldElement<F>,
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    lo.iter_mut().zip(hi.iter_mut()).for_each(|(upper, lower)| {
        let scaled = tw * &*lower; // F × E → E
        // Compute the difference first so the sum can overwrite `upper`.
        let difference = &*upper - &scaled;
        *upper = &*upper + &scaled;
        *lower = difference;
    });
}

/// Runs one first-half DIT layer over a row-chunk of `m`-column rows.
///
/// The chunk is cut into blocks of `2^(layer+1)` rows. Inside each block, row
/// `j` of the top half is paired with row `j` of the bottom half, and the pair
/// uses twiddle `tw[j · 2^(log_n−1−layer)]` from the natural-order table.
#[cfg(feature = "alloc")]
fn dit_first_half_layer<F, E>(
    chunk: &mut [FieldElement<E>],
    m: usize,
    layer: usize,
    log_n: usize,
    tw: &[FieldElement<F>],
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    let pairs = 1usize << layer;
    let rows_per_block = pairs * 2;
    let tw_stride = 1usize << (log_n - 1 - layer);

    chunk.chunks_mut(rows_per_block * m).for_each(|block| {
        let (upper, lower) = block.split_at_mut(pairs * m);
        for pair in 0..pairs {
            let factor = &tw[pair * tw_stride];
            // Column span of row `pair` within either half of the block.
            let cols = pair * m..pair * m + m;
            dit_butterfly_rows(&mut upper[cols.clone()], &mut lower[cols], factor);
        }
    });
}

/// Runs one second-half DIT layer over the row-chunk with index `thread`.
///
/// The chunk is cut into blocks of `2^(log_n−layer)` rows; the top half of a
/// block is butterflied against its bottom half with a single twiddle taken
/// from the bit-reversed table. Block `b` of this chunk reads
/// `bitrev_tw[(thread << (layer − mid)) + b]`, i.e. `thread << (layer − mid)`
/// is used as the index of the chunk's first block for this layer.
#[cfg(feature = "alloc")]
fn dit_second_half_layer<F, E>(
    chunk: &mut [FieldElement<E>],
    m: usize,
    layer: usize,
    log_n: usize,
    mid: usize,
    thread: usize,
    bitrev_tw: &[FieldElement<F>],
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    let rows_per_half = 1usize << (log_n - 1 - layer);
    let rows_per_block = rows_per_half * 2;
    let base = thread << (layer - mid);

    chunk
        .chunks_mut(rows_per_block * m)
        .enumerate()
        .for_each(|(offset, block)| {
            let factor = &bitrev_tw[base + offset];
            let (upper, lower) = block.split_at_mut(rows_per_half * m);
            dit_butterfly_rows(upper, lower, factor);
        });
}

/// Cache-blocked, transpose-free batched FFT over a row-major buffer.
///
/// `buf` holds `n` rows of `num_cols` consecutive elements each; `tw` must be
/// the twiddle set built for size `n` in the wanted direction. The result is
/// the natural-order DFT of every column (the same output as
/// `bowers_fft_batch_row_major` followed by
/// `in_place_bit_reverse_permute_row_major`). No `1/n` scaling is applied for
/// inverse transforms — the caller is expected to fold it in elsewhere (e.g.
/// into the coset-weight pass of the LDE).
///
/// An empty buffer, `num_cols == 0`, or a single-row buffer is a no-op.
///
/// # Errors
///
/// Returns [`FFTError::InputError`] when the buffer length is not a multiple
/// of `num_cols`, when the row count is not a power of two, or when the row
/// count does not match the size `tw` was built for.
#[cfg(feature = "alloc")]
pub fn fft_batch_two_half<F, E>(
    buf: &mut [FieldElement<E>],
    num_cols: usize,
    tw: &TwoHalfTwiddles<F>,
) -> Result<(), FFTError>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField,
    FieldElement<F>: Sync,
    FieldElement<E>: Send + Sync,
{
    // Nothing to transform.
    if buf.is_empty() || num_cols == 0 {
        return Ok(());
    }

    let len = buf.len();
    if !len.is_multiple_of(num_cols) {
        return Err(FFTError::InputError(len));
    }

    let rows = len / num_cols;
    let log_rows = rows.trailing_zeros() as usize;
    // Both a non-power-of-two row count and a twiddle-size mismatch report
    // the offending row count.
    if !rows.is_power_of_two() || log_rows != tw.log_n {
        return Err(FFTError::InputError(rows));
    }
    // A single row is its own transform.
    if log_rows == 0 {
        return Ok(());
    }

    let split = log_rows.div_ceil(2);
    let natural_tw = &tw.tw;
    let reversed_tw = &tw.bitrev_tw;

    // Pass 1: row bit-reversal.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    // Pass 2: layers 0..split, each confined to a 2^split-row chunk. Every
    // chunk runs the same layers with the same twiddles.
    let early_len = (1usize << split) * num_cols;
    #[cfg(feature = "parallel")]
    let early_chunks = buf.par_chunks_mut(early_len);
    #[cfg(not(feature = "parallel"))]
    let early_chunks = buf.chunks_mut(early_len);
    early_chunks.for_each(|rows_chunk| {
        for layer in 0..split {
            dit_first_half_layer::<F, E>(rows_chunk, num_cols, layer, log_rows, natural_tw);
        }
    });

    // Pass 3: row bit-reversal.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    // Pass 4: layers split..log_rows, each confined to a 2^(log_rows-split)-row
    // chunk; the chunk index selects the twiddles.
    let late_len = (1usize << (log_rows - split)) * num_cols;
    #[cfg(feature = "parallel")]
    let late_chunks = buf.par_chunks_mut(late_len).enumerate();
    #[cfg(not(feature = "parallel"))]
    let late_chunks = buf.chunks_mut(late_len).enumerate();
    late_chunks.for_each(|(chunk_idx, rows_chunk)| {
        for layer in split..log_rows {
            dit_second_half_layer::<F, E>(
                rows_chunk,
                num_cols,
                layer,
                log_rows,
                split,
                chunk_idx,
                reversed_tw,
            );
        }
    });

    // Pass 5: last row bit-reversal, leaving the output in natural order.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    Ok(())
}
93 changes: 93 additions & 0 deletions crypto/math/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::fft::bowers_fft::{LayerTwiddles, bowers_fft_opt_fused, bowers_ifft_op
#[cfg(feature = "parallel")]
use crate::fft::bowers_fft::{bowers_fft_opt_fused_parallel, bowers_ifft_opt_parallel};
use crate::fft::errors::FFTError;
use crate::fft::two_half_fft::{TwoHalfTwiddles, fft_batch_two_half};
use crate::field::traits::{IsFFTField, IsField, IsSubFieldOf};
use alloc::{borrow::ToOwned, vec, vec::Vec};

Expand Down Expand Up @@ -502,6 +503,98 @@ impl<E: IsField> Polynomial<FieldElement<E>> {

Ok(())
}

/// Batched row-major coset LDE expansion.
///
/// `buffer` is the row-major flat layout of `n * num_cols` elements
/// (input trace evaluations on the natural-order domain, all M columns
/// interleaved per row). It is expanded in place to length
/// `n * blowup_factor * num_cols`, also row-major, holding the LDE
/// evaluations on the coset.
///
/// Pipeline mirrors [`coset_lde_full_expand`] cell-for-cell, just with
/// the row-major batched two-half FFT so the M columns share twiddle
/// loads inside each butterfly:
///   1. batched, unscaled inverse FFT over rows[..n] (natural → natural)
///   2. scale rows[..n] by coset weights (one weight per row, applied to
///      all M elements of that row)
///   3. zero-pad rows to `n * blowup_factor`
///   4. batched forward FFT over all rows (natural-order output)
///
/// `weights` must hold at least `n` base-field elements in natural row
/// order; only the first `n` are used. Because step 1 applies no `1/n`
/// normalization, that factor has to be folded into `weights`.
/// `inv_twiddles` are the size-`n` inverse two-half twiddles; `fwd_twiddles`
/// the size-`n·blowup_factor` forward ones.
///
/// An empty `buffer` or `num_cols == 0` is a no-op.
///
/// # Errors
///
/// * [`FFTError::InputError`] if the buffer length is not a multiple of
///   `num_cols`, `n` or `blowup_factor` is not a power of two, the expanded
///   size overflows `usize`, `weights` has fewer than `n` elements, or a
///   twiddle set does not match the size it is used with.
/// * [`FFTError::DomainSizeError`] if `log2(n · blowup_factor)` exceeds the
///   two-adicity of `F`.
///
/// All checks except the twiddle-size ones run before `buffer` is touched.
/// A `fwd_twiddles` mismatch is only detected by the final FFT, at which
/// point `buffer` has already been inverse-transformed, scaled and resized.
pub fn coset_lde_full_expand_row_major<F: IsFFTField + IsSubFieldOf<E> + Send + Sync>(
    buffer: &mut Vec<FieldElement<E>>,
    num_cols: usize,
    blowup_factor: usize,
    weights: &[FieldElement<F>],
    inv_twiddles: &TwoHalfTwiddles<F>,
    fwd_twiddles: &TwoHalfTwiddles<F>,
) -> Result<(), FFTError>
where
    E: Send + Sync,
{
    if num_cols == 0 || buffer.is_empty() {
        return Ok(());
    }
    let total = buffer.len();
    if !total.is_multiple_of(num_cols) {
        return Err(FFTError::InputError(total));
    }
    let n = total / num_cols;
    if !n.is_power_of_two() {
        return Err(FFTError::InputError(n));
    }
    // Reject a zero or non-power-of-two blowup up front: otherwise the
    // failure would only surface in the final FFT, after `buffer` has
    // already been overwritten (or, for zero, emptied).
    if !blowup_factor.is_power_of_two() {
        return Err(FFTError::InputError(blowup_factor));
    }
    let lde_n = n
        .checked_mul(blowup_factor)
        .ok_or(FFTError::InputError(blowup_factor))?;
    let lde_len = lde_n
        .checked_mul(num_cols)
        .ok_or(FFTError::InputError(lde_n))?;
    let log_lde_n = lde_n.trailing_zeros();
    if u64::from(log_lde_n) > F::TWO_ADICITY {
        return Err(FFTError::DomainSizeError(log_lde_n as usize));
    }
    if weights.len() < n {
        return Err(FFTError::InputError(weights.len()));
    }

    // 1. iFFT on rows[..n] (cache-blocked two-half; natural→natural, no 1/n
    //    — the 1/n is folded into the coset-weight pass below).
    let prefix_len = n * num_cols;
    fft_batch_two_half::<F, E>(&mut buffer[..prefix_len], num_cols, inv_twiddles)?;

    // 2. Scale by coset weights — one weight per row, multiply M elements
    //    of that row by it. Each row is independent → parallelizable.
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
        buffer[..prefix_len]
            .par_chunks_exact_mut(num_cols)
            .enumerate()
            .for_each(|(r, row)| {
                let w = &weights[r];
                for x in row.iter_mut() {
                    *x = w * &*x;
                }
            });
    }
    #[cfg(not(feature = "parallel"))]
    {
        // `weights.len() >= n` was checked above, so the zip covers every row.
        for (row, w) in buffer[..prefix_len]
            .chunks_exact_mut(num_cols)
            .zip(weights.iter())
        {
            for x in row.iter_mut() {
                *x = w * &*x;
            }
        }
    }

    // 3. Zero-pad rows to lde_n.
    buffer.resize(lde_len, FieldElement::zero());

    // 4. Forward FFT over the padded buffer (cache-blocked two-half;
    //    natural-order output).
    fft_batch_two_half::<F, E>(buffer, num_cols, fwd_twiddles)?;

    Ok(())
}
}

#[cfg(test)]
Expand Down
Loading
Loading