diff --git a/Cargo.lock b/Cargo.lock index 56f65fcf5..0936ed4ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3276,6 +3276,8 @@ dependencies = [ "math", "math-cuda", "memmap2", + "rand 0.8.5", + "rand_chacha 0.3.1", "rayon", "serde", "serde-wasm-bindgen", diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index d2e49947e..b2f61f9a2 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -114,4 +114,5 @@ fn main() { compile_ptx("barycentric.cu", "barycentric.ptx", have_nvcc); compile_ptx("deep.cu", "deep.ptx", have_nvcc); compile_ptx("fri.cu", "fri.ptx", have_nvcc); + compile_ptx("inverse.cu", "inverse.ptx", have_nvcc); } diff --git a/crypto/math-cuda/kernels/inverse.cu b/crypto/math-cuda/kernels/inverse.cu new file mode 100644 index 000000000..96a4364ad --- /dev/null +++ b/crypto/math-cuda/kernels/inverse.cu @@ -0,0 +1,313 @@ +// Parallel Montgomery batch inverse over ext3 elements. +// +// Algorithm: given a[0..N-1] all non-zero, compute a^{-1}[0..N-1] using +// prefix[i] = a[0] * a[1] * ... * a[i] (inclusive forward scan) +// suffix[i] = a[i] * a[i+1] * ... * a[N-1] (inclusive backward scan) +// total = prefix[N-1] = suffix[0] +// inv_total = 1 / total (one Fermat inversion on host) +// a^{-1}[i] = prefix[i-1] * inv_total * suffix[i+1] (boundaries use identity) +// +// Each scan is a multi-block 3-phase Hillis-Steele scan in shared memory: +// Phase 1: each block does an inclusive scan over its 256 elements and +// writes its block sum to a per-block totals array. +// Phase 2: recursively scan the block totals (host re-launches this same +// kernel set; recursion depth = ceil(log_256(N))). +// Phase 3: each block reads its offset (the inclusive prefix of all +// preceding block sums) and multiplies it into every element. +// +// Forward and backward kernels are mirrors of each other. +// +// Buffer layouts: all ext3 buffers are interleaved [a0,b0,c0, a1,b1,c1, ...] +// with one u64 per coordinate. 
// `BLOCK_SIZE = 256` ext3 elements per block
// uses 6 KB of shared memory, well under the per-SM limit on Ada/Blackwell.

#include "goldilocks.cuh"
#include "ext3.cuh"

#define BLOCK_SIZE 256

// ---------------------------------------------------------------------------
// 1. compute_denoms_ext3
//
// Builds the denominators that the batch inverse consumes: one ext3 element
// per (scalar k, point i) pair, stored at flat index k * n + i.
//
//   subtract_x == 0 (R3 OOD convention):  denoms[k * n + i] = z[k] - x[i]
//       Matches CPU `barycentric_inv_denoms(z, points)` = 1/(z - points[i]).
//   subtract_x != 0 (R4 DEEP convention): denoms[k * n + i] = x[i] - z[k]
//       Matches CPU R4 `denoms.push(x_i - z_k)` convention.
//
// x[i] is a base-field value lifted to (x[i], 0, 0), so only the first
// coordinate involves a subtraction; the other two are z's coordinates,
// copied or negated.
//
// Output is ext3-interleaved of length 3 * k_scalars * n.
//
// Launch: grid = ceil(total / BLOCK_SIZE) blocks of BLOCK_SIZE threads, where
// total = k_scalars * n. Each thread builds one denominator.
// ---------------------------------------------------------------------------
extern "C" __global__ void compute_denoms_ext3(
    const uint64_t *x_base,       // n u64
    const uint64_t *z_scalars,    // 3 * k_scalars u64
    uint64_t n,
    uint64_t k_scalars,
    uint64_t subtract_x,          // 0: z - x; otherwise: x - z
    uint64_t *denoms_out          // 3 * k_scalars * n u64
) {
    const uint64_t idx = (uint64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x;
    // Threads past the end of the (k, i) grid have nothing to do. This guard
    // also keeps the division below from ever seeing n == 0.
    if (idx >= k_scalars * n) return;

    const uint64_t scalar = idx / n;   // which z
    const uint64_t point = idx % n;    // which x

    const uint64_t x = x_base[point];
    const uint64_t *z_src = z_scalars + scalar * 3;
    ext3::Fe3 z = {z_src[0], z_src[1], z_src[2]};

    ext3::Fe3 denom;
    if (subtract_x != 0) {
        // x - z with x lifted to (x, 0, 0).
        denom.a = goldilocks::sub(x, z.a);
        denom.b = goldilocks::neg(z.b);
        denom.c = goldilocks::neg(z.c);
    } else {
        // z - x with x lifted to (x, 0, 0).
        denom.a = goldilocks::sub(z.a, x);
        denom.b = z.b;
        denom.c = z.c;
    }

    uint64_t *dst = denoms_out + idx * 3;
    dst[0] = denom.a;
    dst[1] = denom.b;
    dst[2] = denom.c;
}

// ---------------------------------------------------------------------------
// 2.
// block_inclusive_scan_fwd_ext3
//
// Phase 1 of the forward (prefix) scan. Every block runs an inclusive
// Hillis-Steele product scan over its BLOCK_SIZE elements in shared memory:
//   scan_out[gid]     = input[block_start] * ... * input[gid]
//   block_totals[bid] = product of every valid element of block `bid`
//
// Threads whose gid falls past n load the multiplicative identity, so a
// partial last block still produces a correct scan.
// ---------------------------------------------------------------------------
extern "C" __global__ void block_inclusive_scan_fwd_ext3(
    const uint64_t *input,        // 3 * n u64
    uint64_t n,
    uint64_t *scan_out,           // 3 * n u64
    uint64_t *block_totals        // 3 * K u64, K = ceil(n / BLOCK_SIZE)
) {
    __shared__ ext3::Fe3 tile[BLOCK_SIZE];

    const uint64_t lane = threadIdx.x;
    const uint64_t block_base = (uint64_t)blockIdx.x * BLOCK_SIZE;
    const uint64_t gid = block_base + lane;
    const bool in_range = gid < n;

    // Stage this thread's element (or the identity past the end of the data).
    if (in_range) {
        const uint64_t *src = input + gid * 3;
        tile[lane].a = src[0];
        tile[lane].b = src[1];
        tile[lane].c = src[2];
    } else {
        tile[lane] = ext3::one();
    }
    __syncthreads();

    // Hillis-Steele doubling: 8 rounds for BLOCK_SIZE = 256. Each round reads
    // the left neighbour, waits for everyone to finish reading, then writes.
    for (uint32_t stride = 1; stride < BLOCK_SIZE; stride *= 2) {
        const bool active = lane >= stride;
        ext3::Fe3 left = ext3::one();
        if (active) {
            left = tile[lane - stride];
        }
        __syncthreads();
        if (active) {
            tile[lane] = ext3::mul(left, tile[lane]);
        }
        __syncthreads();
    }

    // Publish the per-element scan value.
    if (in_range) {
        uint64_t *dst = scan_out + gid * 3;
        dst[0] = tile[lane].a;
        dst[1] = tile[lane].b;
        dst[2] = tile[lane].c;
    }

    // The block total is the scan value held by the last VALID thread of the
    // block, i.e. gid == min(block_base + BLOCK_SIZE - 1, n - 1). Selecting
    // that single gid (rather than `lane == 255 || gid == n - 1`) guarantees
    // exactly one writer per block, so a partial last block cannot race.
    const uint64_t block_last = block_base + (BLOCK_SIZE - 1);
    const uint64_t last_valid = (block_last < n - 1) ? block_last : (n - 1);
    if (gid == last_valid) {
        uint64_t *total = block_totals + (uint64_t)blockIdx.x * 3;
        total[0] = tile[lane].a;
        total[1] = tile[lane].b;
        total[2] = tile[lane].c;
    }
}

// ---------------------------------------------------------------------------
// 3. apply_block_offsets_fwd_ext3
//
// Phase 3 of the forward scan: each block b > 0 multiplies its per-block
// scan by `block_totals_scanned[b-1]` (the inclusive prefix of preceding
// block totals). Block 0 has no offset, so it returns early.
// ---------------------------------------------------------------------------
extern "C" __global__ void apply_block_offsets_fwd_ext3(
    uint64_t *scan_inout,                 // 3 * n u64 (modified in place)
    uint64_t n,
    const uint64_t *block_totals_scanned  // 3 * K u64, inclusive prefix of phase-1 totals
) {
    const uint64_t block = blockIdx.x;
    const uint64_t gid = block * BLOCK_SIZE + threadIdx.x;
    // Block 0 carries nothing in; out-of-range threads own no element.
    if (block == 0 || gid >= n) return;

    const uint64_t *carry_src = block_totals_scanned + (block - 1) * 3;
    uint64_t *slot = scan_inout + gid * 3;

    ext3::Fe3 carry = {carry_src[0], carry_src[1], carry_src[2]};
    ext3::Fe3 local = {slot[0], slot[1], slot[2]};
    ext3::Fe3 merged = ext3::mul(carry, local);

    slot[0] = merged.a;
    slot[1] = merged.b;
    slot[2] = merged.c;
}

// ---------------------------------------------------------------------------
// 4. block_inclusive_scan_rev_ext3
//
// Mirror of `block_inclusive_scan_fwd_ext3` for the suffix product:
//   suffix[i] = input[i] * input[i+1] * ... * input[n-1]
//
// Block b processes pos_from_end in [b*B, (b+1)*B), where gid = n-1-pos_from_end.
// Inside shmem the order is reversed so a forward Hillis-Steele scan over
// the loaded values produces the suffix scan in the original index space.
// ---------------------------------------------------------------------------
extern "C" __global__ void block_inclusive_scan_rev_ext3(
    const uint64_t *input,        // 3 * n u64
    uint64_t n,
    uint64_t *scan_out,           // 3 * n u64
    uint64_t *block_totals        // 3 * K u64, K = ceil(n / BLOCK_SIZE)
) {
    __shared__ ext3::Fe3 tile[BLOCK_SIZE];

    const uint64_t lane = threadIdx.x;
    const uint64_t block_base = (uint64_t)blockIdx.x * BLOCK_SIZE;
    // Distance of this thread's element from the END of the array; lane 0 of
    // block 0 owns input[n-1].
    const uint64_t back = block_base + lane;
    const bool in_range = back < n;
    // Out-of-range threads never dereference gid; 0 is just a placeholder.
    const uint64_t gid = in_range ? (n - 1 - back) : 0;

    if (in_range) {
        const uint64_t *src = input + gid * 3;
        tile[lane].a = src[0];
        tile[lane].b = src[1];
        tile[lane].c = src[2];
    } else {
        tile[lane] = ext3::one();
    }
    __syncthreads();

    // Same doubling scan as the forward kernel; because the tile is loaded
    // back-to-front, a prefix over lanes is a suffix over array indices.
    for (uint32_t stride = 1; stride < BLOCK_SIZE; stride *= 2) {
        const bool active = lane >= stride;
        ext3::Fe3 left = ext3::one();
        if (active) {
            left = tile[lane - stride];
        }
        __syncthreads();
        if (active) {
            tile[lane] = ext3::mul(left, tile[lane]);
        }
        __syncthreads();
    }

    if (in_range) {
        uint64_t *dst = scan_out + gid * 3;
        dst[0] = tile[lane].a;
        dst[1] = tile[lane].b;
        dst[2] = tile[lane].c;
    }

    // Single-writer block total (same idea as fwd): the last valid position
    // from the end in this block is min(block_base + BLOCK_SIZE - 1, n - 1).
    const uint64_t block_last = block_base + (BLOCK_SIZE - 1);
    const uint64_t last_valid = (block_last < n - 1) ? block_last : (n - 1);
    if (back == last_valid) {
        uint64_t *total = block_totals + (uint64_t)blockIdx.x * 3;
        total[0] = tile[lane].a;
        total[1] = tile[lane].b;
        total[2] = tile[lane].c;
    }
}

// ---------------------------------------------------------------------------
// 5. apply_block_offsets_rev_ext3
//
// Phase 3 of the suffix scan. Block b > 0 multiplies its per-block scan
// by the inclusive prefix of block totals from blocks [0..b-1] (which, in
// the reverse-block indexing, correspond to the indices LARGER than this
// block's gids).
// ---------------------------------------------------------------------------
extern "C" __global__ void apply_block_offsets_rev_ext3(
    uint64_t *scan_inout,                 // 3 * n u64 (modified in place)
    uint64_t n,
    const uint64_t *block_totals_scanned  // 3 * K u64, inclusive prefix of phase-1 totals
) {
    const uint64_t block = blockIdx.x;
    const uint64_t back = block * BLOCK_SIZE + threadIdx.x;  // position from the end
    // Block 0 covers the tail of the array and needs no carry; threads past
    // the data own no element.
    if (block == 0 || back >= n) return;

    const uint64_t gid = n - 1 - back;
    const uint64_t *carry_src = block_totals_scanned + (block - 1) * 3;
    uint64_t *slot = scan_inout + gid * 3;

    ext3::Fe3 carry = {carry_src[0], carry_src[1], carry_src[2]};
    ext3::Fe3 local = {slot[0], slot[1], slot[2]};
    ext3::Fe3 merged = ext3::mul(carry, local);

    slot[0] = merged.a;
    slot[1] = merged.b;
    slot[2] = merged.c;
}

// ---------------------------------------------------------------------------
// 6. batch_inverse_combine_ext3
//
// out[i] = prefix[i-1] * inv_total * suffix[i+1]
//
// Boundaries: prefix[-1] = identity, suffix[n] = identity.
// inv_total = 1 / (prefix[n-1]) = 1 / (suffix[0]); the caller computes it
// on host via Fermat's little theorem (one extension-field inverse per
// batch) and uploads as a 3 * u64 device buffer.
+// --------------------------------------------------------------------------- +extern "C" __global__ void batch_inverse_combine_ext3( + const uint64_t *prefix, // 3 * n u64 + const uint64_t *suffix, // 3 * n u64 + const uint64_t *inv_total, // 3 u64 + uint64_t n, + uint64_t *out // 3 * n u64 +) { + uint64_t i = (uint64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x; + if (i >= n) return; + + ext3::Fe3 inv_t = {inv_total[0], inv_total[1], inv_total[2]}; + + ext3::Fe3 p; + if (i == 0) { + p = ext3::one(); + } else { + p.a = prefix[(i - 1) * 3 + 0]; + p.b = prefix[(i - 1) * 3 + 1]; + p.c = prefix[(i - 1) * 3 + 2]; + } + + ext3::Fe3 s; + if (i == n - 1) { + s = ext3::one(); + } else { + s.a = suffix[(i + 1) * 3 + 0]; + s.b = suffix[(i + 1) * 3 + 1]; + s.c = suffix[(i + 1) * 3 + 2]; + } + + ext3::Fe3 tmp = ext3::mul(p, inv_t); + ext3::Fe3 res = ext3::mul(tmp, s); + + out[i * 3 + 0] = res.a; + out[i * 3 + 1] = res.b; + out[i * 3 + 2] = res.c; +} diff --git a/crypto/math-cuda/src/barycentric.rs b/crypto/math-cuda/src/barycentric.rs index b4eb12dfd..f299d1839 100644 --- a/crypto/math-cuda/src/barycentric.rs +++ b/crypto/math-cuda/src/barycentric.rs @@ -7,7 +7,9 @@ //! `(z^N - g^N) * 1/N * 1/g^N` to get the final OOD value. That scaling is //! one ext3 mul per column and stays on host. -use cudarc::driver::{LaunchConfig, PushKernelArg}; +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; use crate::Result; use crate::device::backend; @@ -177,6 +179,65 @@ pub fn barycentric_base_on_device( Ok(out) } +/// Same as [`barycentric_base_on_device`] but reads `inv_denoms` AND +/// `coset_points` from device handles (no per-call H2D) and runs on the +/// caller's stream (so the inv_denoms producer and this kernel serialize +/// naturally). +/// +/// `inv_denoms_dev` is the full multi-eval-point buffer from +/// `compute_and_invert_denoms_ext3_dev`. 
`inv_offset_u64` is the start +/// of this eval point's block (in u64s), so the kernel reads +/// `inv_denoms_dev[inv_offset_u64 .. inv_offset_u64 + 3*n]`. +pub fn barycentric_base_on_device_with_dev_inv_denoms( + stream: &Arc, + main_handle: &GpuLdeBase, + row_stride: usize, + coset_points_dev: &CudaSlice, + inv_denoms_dev: &CudaSlice, + inv_offset_u64: usize, + n: usize, +) -> Result> { + assert!(coset_points_dev.len() >= n); + let inv_end = inv_offset_u64 + .checked_add(3 * n) + .expect("barycentric inv_denoms range overflow"); + assert!(inv_end <= inv_denoms_dev.len()); + let num_cols = main_handle.m; + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + let col_stride = main_handle.lde_size; + + let be = backend()?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + let inv_view = inv_denoms_dev.slice(inv_offset_u64..inv_end); + let points_view = coset_points_dev.slice(0..n); + + let col_stride_u64 = col_stride as u64; + let row_stride_u64 = row_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_base_batched_strided) + .arg(main_handle.buf.as_ref()) + .arg(&col_stride_u64) + .arg(&row_stride_u64) + .arg(&points_view) + .arg(&inv_view) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + /// Ext3 counterpart of [`barycentric_base_on_device`]. Reads the aux LDE /// from the de-interleaved device handle. pub fn barycentric_ext3_on_device( @@ -225,3 +286,54 @@ pub fn barycentric_ext3_on_device( stream.synchronize()?; Ok(out) } + +/// Ext3 counterpart of [`barycentric_base_on_device_with_dev_inv_denoms`]. 
+pub fn barycentric_ext3_on_device_with_dev_inv_denoms( + stream: &Arc, + aux_handle: &GpuLdeExt3, + row_stride: usize, + coset_points_dev: &CudaSlice, + inv_denoms_dev: &CudaSlice, + inv_offset_u64: usize, + n: usize, +) -> Result> { + assert!(coset_points_dev.len() >= n); + let inv_end = inv_offset_u64 + .checked_add(3 * n) + .expect("barycentric inv_denoms range overflow"); + assert!(inv_end <= inv_denoms_dev.len()); + let num_cols = aux_handle.m; + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + let col_stride = aux_handle.lde_size; + + let be = backend()?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + let inv_view = inv_denoms_dev.slice(inv_offset_u64..inv_end); + let points_view = coset_points_dev.slice(0..n); + + let col_stride_u64 = col_stride as u64; + let row_stride_u64 = row_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_ext3_batched_strided) + .arg(aux_handle.buf.as_ref()) + .arg(&col_stride_u64) + .arg(&row_stride_u64) + .arg(&points_view) + .arg(&inv_view) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/deep.rs b/crypto/math-cuda/src/deep.rs index 605132529..581fbc404 100644 --- a/crypto/math-cuda/src/deep.rs +++ b/crypto/math-cuda/src/deep.rs @@ -8,7 +8,9 @@ //! `domain_size * 3` u64s, ext3 interleaved (ready to `transmute` to //! `FieldElement` when the caller promises layout compatibility). 
-use cudarc::driver::{LaunchConfig, PushKernelArg}; +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; use crate::Result; use crate::device::backend; @@ -39,7 +41,10 @@ pub fn deep_composition_ext3( row_stride: usize, domain_size: usize, ) -> Result> { + let be = backend()?; + let stream = be.next_stream(); deep_composition_ext3_impl( + &stream, main_lde, aux_lde, None, @@ -81,7 +86,10 @@ pub fn deep_composition_ext3_with_dev_parts( row_stride: usize, domain_size: usize, ) -> Result> { + let be = backend()?; + let stream = be.next_stream(); deep_composition_ext3_impl( + &stream, main_lde, aux_lde, Some(h_parts_dev), @@ -101,8 +109,137 @@ pub fn deep_composition_ext3_with_dev_parts( ) } +/// Fully device-resident R4 DEEP path: parts LDE and inverse denominators +/// both arrive as device handles, the caller threads its own stream +/// through so the inv_denoms producer +/// (`compute_and_invert_denoms_ext3_dev`) and this kernel run on the same +/// stream (no cross-stream race). H2Ds only the small OOD/gamma scalars. +/// +/// `inv_denoms_dev` is `3 * (1 + num_eval_points) * domain_size` u64s: +/// the first `3 * domain_size` u64s are `inv_h` (H-term denominators), +/// followed by `num_eval_points` blocks of `3 * domain_size` for the +/// trace terms. Same layout `compute_and_invert_denoms_ext3_dev` +/// produces when called with `z_scalars = [z_power, z_shifted[0..]]`. 
+#[allow(clippy::too_many_arguments)] +pub fn deep_composition_ext3_with_dev_parts_and_inv_denoms( + stream: &Arc, + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + h_parts_dev: &GpuLdeExt3, + inv_denoms_dev: &CudaSlice, + h_ood: &[u64], + trace_ood: &[u64], + gammas_h: &[u64], + gammas_tr: &[u64], + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + row_stride: usize, + domain_size: usize, +) -> Result> { + assert_eq!(main_lde.m, num_main); + assert_eq!(h_parts_dev.m, num_parts); + assert_eq!(h_parts_dev.lde_size, main_lde.lde_size); + if let Some(a) = aux_lde { + assert_eq!(a.m, num_aux); + assert_eq!(a.lde_size, main_lde.lde_size); + } else { + assert_eq!(num_aux, 0); + } + assert_eq!(h_ood.len(), num_parts * 3); + let num_total_cols = num_main + num_aux; + assert_eq!(trace_ood.len(), num_total_cols * num_eval_points * 3); + assert_eq!(gammas_h.len(), num_parts * 3); + assert_eq!(gammas_tr.len(), num_total_cols * num_eval_points * 3); + + let ext3_size = domain_size + .checked_mul(3) + .expect("deep composition: domain_size * 3 overflow"); + let expected_inv_denoms = ext3_size + .checked_mul(1 + num_eval_points) + .expect("deep composition: inv_denoms length overflow"); + assert_eq!(inv_denoms_dev.len(), expected_inv_denoms); + + if domain_size > 0 { + let max_row = (domain_size - 1) + .checked_mul(row_stride) + .expect("deep composition: (domain_size - 1) * row_stride overflow"); + assert!( + max_row < main_lde.lde_size, + "deep composition: kernel row {max_row} out of LDE stride {}", + main_lde.lde_size + ); + } + + let be = backend()?; + + // H2D only the small scalars on the caller's stream. + let h_ood_dev = stream.clone_htod(h_ood)?; + let trace_ood_dev = stream.clone_htod(trace_ood)?; + let gammas_h_dev = stream.clone_htod(gammas_h)?; + let gammas_tr_dev = stream.clone_htod(gammas_tr)?; + + // Slice the inv_denoms buffer into the H-term and trace-term views. 
+ let inv_h_view = inv_denoms_dev.slice(0..ext3_size); + let inv_t_view = inv_denoms_dev.slice(ext3_size..expected_inv_denoms); + + // SAFETY: every output slot is written by the kernel. + let mut deep_out = unsafe { stream.alloc::(domain_size * 3) }?; + + let dummy_aux; + let aux_slice = if let Some(a) = aux_lde { + a.buf.as_ref() + } else { + dummy_aux = stream.alloc_zeros::(1)?; + &dummy_aux + }; + + let lde_stride = main_lde.lde_size as u64; + let num_main_u = num_main as u64; + let num_aux_u = num_aux as u64; + let num_parts_u = num_parts as u64; + let num_eval_points_u = num_eval_points as u64; + let row_stride_u = row_stride as u64; + let domain_size_u = domain_size as u64; + + let grid = (domain_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.deep_composition_ext3_row) + .arg(main_lde.buf.as_ref()) + .arg(aux_slice) + .arg(h_parts_dev.buf.as_ref()) + .arg(&lde_stride) + .arg(&num_main_u) + .arg(&num_aux_u) + .arg(&num_parts_u) + .arg(&num_eval_points_u) + .arg(&row_stride_u) + .arg(&domain_size_u) + .arg(&h_ood_dev) + .arg(&trace_ood_dev) + .arg(&gammas_h_dev) + .arg(&gammas_tr_dev) + .arg(&inv_h_view) + .arg(&inv_t_view) + .arg(&mut deep_out) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&deep_out)?; + stream.synchronize()?; + Ok(out) +} + #[allow(clippy::too_many_arguments)] fn deep_composition_ext3_impl( + stream: &Arc, main_lde: &GpuLdeBase, aux_lde: Option<&GpuLdeExt3>, h_parts_dev: Option<&GpuLdeExt3>, @@ -155,10 +292,7 @@ fn deep_composition_ext3_impl( } let be = backend()?; - let stream = be.next_stream(); - // H2D only the scalar arrays. h_parts comes from a device handle - // when available. 
let h_ood_dev = stream.clone_htod(h_ood)?; let trace_ood_dev = stream.clone_htod(trace_ood)?; let gammas_h_dev = stream.clone_htod(gammas_h)?; @@ -166,15 +300,13 @@ fn deep_composition_ext3_impl( let inv_h_dev = stream.clone_htod(inv_h)?; let inv_t_dev = stream.clone_htod(inv_t)?; - // Keep the owned H2D of h_lde alive until kernel completes. Only - // populated in the host-parts path. let h_lde_host_dev; + let dummy_aux; // SAFETY: the deep_composition kernel writes every output slot before // any read, so uninitialised contents are never observed. let mut deep_out = unsafe { stream.alloc::(domain_size * 3) }?; - let dummy_aux; let aux_slice = if let Some(a) = aux_lde { a.buf.as_ref() } else { diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 3c98de395..17e2f9f82 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -96,6 +96,7 @@ const KECCAK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/keccak.ptx")); const BARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/barycentric.ptx")); const DEEP_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/deep.ptx")); const FRI_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/fri.ptx")); +const INVERSE_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/inverse.ptx")); /// Number of CUDA streams in the pool. Larger pools let many rayon-parallel /// callers overlap on the GPU without serializing on stream ownership. The @@ -160,6 +161,14 @@ pub struct Backend { pub fri_fold_ext3: CudaFunction, pub fri_update_twiddles: CudaFunction, + // inverse.ptx + pub compute_denoms_ext3: CudaFunction, + pub block_inclusive_scan_fwd_ext3: CudaFunction, + pub apply_block_offsets_fwd_ext3: CudaFunction, + pub block_inclusive_scan_rev_ext3: CudaFunction, + pub apply_block_offsets_rev_ext3: CudaFunction, + pub batch_inverse_combine_ext3: CudaFunction, + // Twiddle caches keyed by log_n. 
fwd_twiddles: Mutex>>>>, inv_twiddles: Mutex>>>>, @@ -180,6 +189,7 @@ impl Backend { let bary = ctx.load_module(Ptx::from_src(BARY_PTX))?; let deep = ctx.load_module(Ptx::from_src(DEEP_PTX))?; let fri = ctx.load_module(Ptx::from_src(FRI_PTX))?; + let inverse = ctx.load_module(Ptx::from_src(INVERSE_PTX))?; let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); for _ in 0..STREAM_POOL_SIZE { @@ -241,6 +251,14 @@ impl Backend { deep_composition_ext3_row: deep.load_function("deep_composition_ext3_row")?, fri_fold_ext3: fri.load_function("fri_fold_ext3")?, fri_update_twiddles: fri.load_function("fri_update_twiddles")?, + compute_denoms_ext3: inverse.load_function("compute_denoms_ext3")?, + block_inclusive_scan_fwd_ext3: inverse + .load_function("block_inclusive_scan_fwd_ext3")?, + apply_block_offsets_fwd_ext3: inverse.load_function("apply_block_offsets_fwd_ext3")?, + block_inclusive_scan_rev_ext3: inverse + .load_function("block_inclusive_scan_rev_ext3")?, + apply_block_offsets_rev_ext3: inverse.load_function("apply_block_offsets_rev_ext3")?, + batch_inverse_combine_ext3: inverse.load_function("batch_inverse_combine_ext3")?, fwd_twiddles: Mutex::new(vec![None; max_log]), inv_twiddles: Mutex::new(vec![None; max_log]), ctx, diff --git a/crypto/math-cuda/src/inverse.rs b/crypto/math-cuda/src/inverse.rs new file mode 100644 index 000000000..2d8b09c08 --- /dev/null +++ b/crypto/math-cuda/src/inverse.rs @@ -0,0 +1,411 @@ +//! Parallel Montgomery batch inverse on the GPU for ext3 elements. +//! +//! The kernels live in `kernels/inverse.cu` and implement a multi-block +//! 3-phase Hillis-Steele scan: each block scans its 256 elements in shmem +//! and emits a block total; the block totals are scanned recursively (the +//! same kernels applied to a smaller array); a final pass multiplies each +//! element by the cumulative offset of preceding blocks. +//! +//! Two public entry points: +//! - `batch_inverse_ext3`: host -> host (parity-test path). +//! 
- `batch_inverse_ext3_dev`: device -> device, returns a `CudaSlice` +//! handle the caller feeds into the next kernel without a D2H+H2D. +//! +//! Plus the fused convenience `compute_and_invert_denoms_ext3_dev` for the +//! R3 OOD and R4 DEEP denominator pipelines. + +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +const BLOCK_SIZE: u32 = 256; + +/// Test-only fault injection. When the `test-faults` feature is on, setting +/// this to a finite value forces the next `compute_and_invert_denoms_ext3_dev` +/// call to return Err and decrement the counter. Tests use this to exercise +/// the CPU-fallback path in `try_compute_and_invert_inv_denoms_dev`. +#[cfg(feature = "test-faults")] +pub static FAULT_INVERSE_REMAINING_UNTIL_ERR: std::sync::atomic::AtomicI64 = + std::sync::atomic::AtomicI64::new(-1); + +#[cfg(feature = "test-faults")] +fn check_inverse_fault_injection() -> Result<()> { + use std::sync::atomic::Ordering; + let v = FAULT_INVERSE_REMAINING_UNTIL_ERR.load(Ordering::Relaxed); + if v < 0 { + return Ok(()); + } + let new = FAULT_INVERSE_REMAINING_UNTIL_ERR.fetch_sub(1, Ordering::Relaxed); + if new == 0 { + return Err(cudarc::driver::DriverError( + cudarc::driver::sys::CUresult::CUDA_ERROR_UNKNOWN, + )); + } + Ok(()) +} + +/// Host-input batch inverse. Returns a fresh `Vec` of length `3 * n` +/// containing the inverses. Used by the parity-test suite; production +/// callers should prefer `batch_inverse_ext3_dev` to avoid the D2H. +pub fn batch_inverse_ext3(a: &[u64]) -> Result> { + assert!(a.len().is_multiple_of(3)); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + if n == 1 { + // Below GPU break-even (one element). Invert on host via the math + // crate's `Fp3::inv`. 
+ let inv = invert_ext3_host([a[0], a[1], a[2]])?; + return Ok(inv.to_vec()); + } + + let be = backend()?; + let stream = be.next_stream(); + let input_dev = stream.clone_htod(a)?; + let out_dev = batch_inverse_ext3_dev(&input_dev, n, &stream)?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Device-input batch inverse. Allocates and returns a fresh `CudaSlice` +/// of length `3 * n` holding the inverses. Requires `n >= 1`. +/// +/// The caller's `stream` is used for every launch and synchronised at the +/// end (so the returned slice's data is committed before this function +/// returns). +pub fn batch_inverse_ext3_dev( + input: &CudaSlice, + n: usize, + stream: &Arc, +) -> Result> { + assert!(n >= 1, "batch_inverse_ext3_dev requires n >= 1"); + // Runtime guard (not debug_assert): a u32 grid_dim is truncated past + // u32::MAX / BLOCK_SIZE, which would silently launch too few blocks + // and leave a tail uninverted. Reachable on LDE size 2^23+ × multi- + // eval-point R4. Returning Err lets the dispatcher's Err(_) => None + // route the caller to the CPU `inplace_batch_inverse` fallback. + if n > u32::MAX as usize / BLOCK_SIZE as usize { + return Err(cudarc::driver::DriverError( + cudarc::driver::sys::CUresult::CUDA_ERROR_INVALID_VALUE, + )); + } + if n == 1 { + // Single element: D2H, host invert, H2D. Avoids running the + // scan + combine machinery for a degenerate case. + let host_view: Vec = stream.clone_dtoh(&input.slice(0..3))?; + stream.synchronize()?; + let inv = invert_ext3_host([host_view[0], host_view[1], host_view[2]])?; + let mut out = unsafe { stream.alloc::(3) }?; + stream.memcpy_htod(&inv, &mut out)?; + return Ok(out); + } + + let be = backend()?; + + // Prefix and suffix scan scratch buffers; fully overwritten by the + // scan kernels, so `alloc` is safe (no need for `alloc_zeros`). + // SAFETY: the multi-block scan kernels write every output slot. 
+ let mut prefix = unsafe { stream.alloc::(3 * n) }?; + let mut suffix = unsafe { stream.alloc::(3 * n) }?; + + scan_into_fwd(stream, be, input, &mut prefix, n)?; + scan_into_rev(stream, be, input, &mut suffix, n)?; + + // total = prefix[n-1] = suffix[0]. Invert on host (one Fermat per batch). + let last_host: Vec = stream.clone_dtoh(&prefix.slice((n - 1) * 3..n * 3))?; + stream.synchronize()?; + let inv_total = invert_ext3_host([last_host[0], last_host[1], last_host[2]])?; + let mut inv_total_dev = unsafe { stream.alloc::(3) }?; + stream.memcpy_htod(&inv_total, &mut inv_total_dev)?; + + // Combine: out[i] = prefix[i-1] * inv_total * suffix[i+1]. + // SAFETY: the combine kernel writes every slot before any read. + let mut out_dev = unsafe { stream.alloc::(3 * n) }?; + let cfg = LaunchConfig { + grid_dim: ((n as u32).div_ceil(BLOCK_SIZE), 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), + shared_mem_bytes: 0, + }; + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix) + .arg(&suffix) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + // No terminal `stream.synchronize()`: the caller's downstream consumers + // (e.g. `barycentric_*_on_device_with_dev_inv_denoms`, + // `deep_composition_ext3_with_dev_parts_and_inv_denoms`) run on the + // same stream and thus observe the combine kernel's writes via + // CUDA's per-stream FIFO ordering. + Ok(out_dev) +} + +/// Sign convention for `compute_and_invert_denoms_ext3_dev`. +#[derive(Copy, Clone)] +pub enum DenomSign { + /// `denoms[k*n+i] = z_scalars[k] - x[i]`. Matches CPU + /// `barycentric_inv_denoms(z, points)` (R3 OOD). + ZMinusX, + /// `denoms[k*n+i] = x[i] - z_scalars[k]`. Matches CPU R4 DEEP + /// `denoms.push(x_i - z_k)`. + XMinusZ, +} + +/// Compute `denoms[k*n + i] = sign-dependent (z, x) combination` on +/// device, then batch-invert. Returns a fresh `CudaSlice` of length +/// `3 * k_scalars * n` holding the inverted denominators. 
Entire pipeline +/// stays on device (no PCIe traffic beyond the small `z_scalars` upload). +pub fn compute_and_invert_denoms_ext3_dev( + x_lde_dev: &CudaSlice, + z_scalars_host: &[u64], + n: usize, + k_scalars: usize, + sign: DenomSign, + stream: &Arc, +) -> Result> { + #[cfg(feature = "test-faults")] + check_inverse_fault_injection()?; + assert_eq!(z_scalars_host.len(), k_scalars * 3); + assert!(n >= 1 && k_scalars >= 1); + + let be = backend()?; + let total = k_scalars + .checked_mul(n) + .expect("compute_and_invert_denoms_ext3_dev: k_scalars * n overflow"); + // See `batch_inverse_ext3_dev` for the rationale: runtime Err, not + // debug_assert, so release builds also route past the silent-truncation + // hazard via the caller's CPU fallback. + if total > u32::MAX as usize / BLOCK_SIZE as usize { + return Err(cudarc::driver::DriverError( + cudarc::driver::sys::CUresult::CUDA_ERROR_INVALID_VALUE, + )); + } + + let z_dev = stream.clone_htod(z_scalars_host)?; + // SAFETY: the compute_denoms_ext3 kernel writes every output slot. + let mut denoms = unsafe { stream.alloc::(3 * total) }?; + let n_u64 = n as u64; + let k_u64 = k_scalars as u64; + let subtract_x_u64: u64 = match sign { + DenomSign::ZMinusX => 0, + DenomSign::XMinusZ => 1, + }; + + let cfg = LaunchConfig { + grid_dim: ((total as u32).div_ceil(BLOCK_SIZE), 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.compute_denoms_ext3) + .arg(x_lde_dev) + .arg(&z_dev) + .arg(&n_u64) + .arg(&k_u64) + .arg(&subtract_x_u64) + .arg(&mut denoms) + .launch(cfg)?; + } + + batch_inverse_ext3_dev(&denoms, total, stream) +} + +// ============================================================================= +// Multi-block recursive scan driver +// ============================================================================= + +/// Recursive driver: writes `prefix_out[i] = product of input[0..=i]` for i in +/// 0..n. 
`input` and `prefix_out` may NOT alias for the top-level call (they +/// alias inside the recursion when scanning block totals in place). +fn scan_into_fwd( + stream: &Arc, + be: &crate::device::Backend, + input: &CudaSlice, + prefix_out: &mut CudaSlice, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + let k = (n as u32).div_ceil(BLOCK_SIZE); + // SAFETY: phase-1 writes every block_totals slot when the kernel emits + // the "last in block" value; partial last block also writes its total. + let mut block_totals = unsafe { stream.alloc::(3 * k as usize) }?; + let n_u64 = n as u64; + + let phase_cfg = LaunchConfig { + grid_dim: (k, 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), + shared_mem_bytes: 0, + }; + + // Phase 1: per-block inclusive scan of `input` into `prefix_out`, + // plus per-block totals into `block_totals`. + unsafe { + stream + .launch_builder(&be.block_inclusive_scan_fwd_ext3) + .arg(input) + .arg(&n_u64) + .arg(&mut *prefix_out) + .arg(&mut block_totals) + .launch(phase_cfg)?; + } + + if k > 1 { + // Phase 2: recursively scan block_totals in place. + scan_inplace_fwd(stream, be, &mut block_totals, k as usize)?; + + // Phase 3: each block reads `block_totals_scanned[blockIdx.x - 1]` + // and multiplies into its in-block scan output. + unsafe { + stream + .launch_builder(&be.apply_block_offsets_fwd_ext3) + .arg(&mut *prefix_out) + .arg(&n_u64) + .arg(&block_totals) + .launch(phase_cfg)?; + } + } + Ok(()) +} + +/// In-place forward scan. Used by the recursion: scanning block totals +/// always reads and writes the same buffer. 
+fn scan_inplace_fwd( + stream: &Arc, + be: &crate::device::Backend, + buf: &mut CudaSlice, + n: usize, +) -> Result<()> { + if n <= 1 { + return Ok(()); + } + let k = (n as u32).div_ceil(BLOCK_SIZE); + let mut block_totals = unsafe { stream.alloc::(3 * k as usize) }?; + let n_u64 = n as u64; + + let phase_cfg = LaunchConfig { + grid_dim: (k, 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), + shared_mem_bytes: 0, + }; + + // Scratch buffer + memcpy_dtod: cudarc's `launch_builder` chains a + // `&buf` read arg and a `&mut buf` write arg, which the borrow checker + // rejects even though the kernel is safe in place. + let mut scratch = unsafe { stream.alloc::(3 * n) }?; + unsafe { + stream + .launch_builder(&be.block_inclusive_scan_fwd_ext3) + .arg(&*buf) + .arg(&n_u64) + .arg(&mut scratch) + .arg(&mut block_totals) + .launch(phase_cfg)?; + } + // Copy scratch back into buf for the apply_block_offsets pass to read+write. + // SAFETY: identical lengths, both on device. + stream.memcpy_dtod(&scratch, buf)?; + + if k > 1 { + scan_inplace_fwd(stream, be, &mut block_totals, k as usize)?; + unsafe { + stream + .launch_builder(&be.apply_block_offsets_fwd_ext3) + .arg(&mut *buf) + .arg(&n_u64) + .arg(&block_totals) + .launch(phase_cfg)?; + } + } + Ok(()) +} + +/// Mirror of `scan_into_fwd` for the suffix scan. 
+fn scan_into_rev( + stream: &Arc, + be: &crate::device::Backend, + input: &CudaSlice, + suffix_out: &mut CudaSlice, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + let k = (n as u32).div_ceil(BLOCK_SIZE); + let mut block_totals = unsafe { stream.alloc::(3 * k as usize) }?; + let n_u64 = n as u64; + + let phase_cfg = LaunchConfig { + grid_dim: (k, 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), + shared_mem_bytes: 0, + }; + + unsafe { + stream + .launch_builder(&be.block_inclusive_scan_rev_ext3) + .arg(input) + .arg(&n_u64) + .arg(&mut *suffix_out) + .arg(&mut block_totals) + .launch(phase_cfg)?; + } + + if k > 1 { + // The reverse-direction phase-2 is itself a forward inclusive scan + // of the (already reverse-indexed) block totals: block_totals[b] + // holds the product over the b-th REVERSE block, and we need an + // inclusive prefix over those for phase 3's offsets. + scan_inplace_fwd(stream, be, &mut block_totals, k as usize)?; + + unsafe { + stream + .launch_builder(&be.apply_block_offsets_rev_ext3) + .arg(&mut *suffix_out) + .arg(&n_u64) + .arg(&block_totals) + .launch(phase_cfg)?; + } + } + Ok(()) +} + +// ============================================================================= +// Host-side ext3 inverse (one element, used to invert the batch total). +// ============================================================================= + +/// Invert one ext3 element on the host via the math crate's `Fp3::inv`. +/// Used once per batch inverse to invert the total product; the main batch +/// inverse work stays on GPU. Returns a cudarc `DriverError` on zero norm +/// so the caller's `Err(_) => None` fallback path fires (instead of +/// panicking past it). 
+fn invert_ext3_host(x: [u64; 3]) -> Result<[u64; 3]> { + use math::field::element::FieldElement; + use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; + use math::field::goldilocks::GoldilocksField; + + type Fp = FieldElement; + type Fp3 = FieldElement; + + let elem = Fp3::new([Fp::from_raw(x[0]), Fp::from_raw(x[1]), Fp::from_raw(x[2])]); + let inv = elem.inv().map_err(|_| { + cudarc::driver::DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_UNKNOWN) + })?; + Ok([ + *inv.value()[0].value(), + *inv.value()[1].value(), + *inv.value()[2].value(), + ]) +} diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index a06481ba2..37f4bc2b7 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -8,10 +8,15 @@ pub mod barycentric; pub mod deep; pub mod device; pub mod fri; +pub mod inverse; pub mod lde; pub mod merkle; pub mod ntt; +// Re-exported for downstream crates so they can refer to CUDA primitive +// types without depending on cudarc directly. +pub use cudarc::driver::{CudaSlice, CudaStream}; + use cudarc::driver::{LaunchConfig, PushKernelArg}; use crate::device::{Backend, backend}; diff --git a/crypto/math-cuda/tests/batch_inverse.rs b/crypto/math-cuda/tests/batch_inverse.rs new file mode 100644 index 000000000..bc52f9fcb --- /dev/null +++ b/crypto/math-cuda/tests/batch_inverse.rs @@ -0,0 +1,106 @@ +//! Parity: GPU parallel batch inverse matches CPU +//! `FieldElement::inplace_batch_inverse` on ext3 elements. +//! +//! Sizes span: +//! - n=1 (host-only path) +//! - n in {2..256} small (single-block scan) +//! - n in {257..2^17} medium (multi-block, single recursion) +//! 
- n=2^20, 2^22 large (multi-block, two-level recursion) + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math_cuda::inverse::batch_inverse_ext3; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + loop { + let v = rng.r#gen::(); + if v != 0 { + return Fp::from_raw(v); + } + } +} + +fn rand_fp3_nonzero(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn canon3(a: &[u64]) -> Vec { + a.iter().map(GoldilocksField::canonical).collect() +} + +fn run(n: usize, seed: u64) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let xs: Vec = (0..n).map(|_| rand_fp3_nonzero(&mut rng)).collect(); + + let mut cpu = xs.clone(); + FieldElement::inplace_batch_inverse(&mut cpu).expect("batch inverse non-zero"); + + let input_u64 = ext3_to_u64s(&xs); + let gpu_u64 = batch_inverse_ext3(&input_u64).unwrap(); + + let cpu_u64 = ext3_to_u64s(&cpu); + let gpu_canon = canon3(&gpu_u64); + let cpu_canon = canon3(&cpu_u64); + + for i in 0..n { + let g = &gpu_canon[i * 3..(i + 1) * 3]; + let c = &cpu_canon[i * 3..(i + 1) * 3]; + assert_eq!(g, c, "mismatch at i={i} n={n}"); + } +} + +#[test] +fn batch_inverse_n1() { + // Host-only special case. + run(1, 1); +} + +#[test] +fn batch_inverse_single_block() { + // All single-block sizes (no recursion). + for n in [2usize, 3, 5, 16, 63, 127, 255, 256] { + run(n, 100 + n as u64); + } +} + +#[test] +fn batch_inverse_two_block() { + // Just over single-block: forces phase 1 + 3 with K = 2. 
+ for n in [257usize, 511, 512, 513, 1024] { + run(n, 200 + n as u64); + } +} + +#[test] +fn batch_inverse_multi_block() { + // Multi-block, single level of recursion (K > 1, K <= 256). + for n in [4096usize, 16384, 65536] { + run(n, 500 + n as u64); + } +} + +#[test] +fn batch_inverse_recursive() { + // K > 256: forces two levels of recursion. fib_iterative_1M + // (lde_size=2^20) and fib_iterative_4M (lde_size=2^22) shapes. + run(1 << 18, 9001); + run(1 << 20, 9002); + run(1 << 22, 9003); +} diff --git a/crypto/math-cuda/tests/compute_and_invert_denoms.rs b/crypto/math-cuda/tests/compute_and_invert_denoms.rs new file mode 100644 index 000000000..a00da8b23 --- /dev/null +++ b/crypto/math-cuda/tests/compute_and_invert_denoms.rs @@ -0,0 +1,112 @@ +//! Parity: GPU `compute_and_invert_denoms_ext3_dev` matches the CPU +//! reference `denoms[k * n + i] = x_lde[i] - z[k]` followed by +//! `inplace_batch_inverse`. Mirrors the shapes used by R3 OOD (n = +//! trace_size, k = num_eval_points) and R4 DEEP (n = lde_size, k = +//! 1 + num_eval_points). 
+ +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math_cuda::device::backend; +use math_cuda::inverse::{DenomSign, compute_and_invert_denoms_ext3_dev}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} + +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn canon3(a: &[u64]) -> Vec { + a.iter().map(GoldilocksField::canonical).collect() +} + +fn run(n: usize, k_scalars: usize, sign: DenomSign, seed: u64) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + // x_lde: base-field, n elements. Avoid the trivial case where x_lde[i] + // happens to equal a z_scalars[k] component (that would make a denom + // zero and trigger the batch-invert zero-norm assert). + let x_lde: Vec = (0..n).map(|_| rand_fp(&mut rng)).collect(); + let z_scalars: Vec = (0..k_scalars).map(|_| rand_fp3(&mut rng)).collect(); + + // CPU reference: denom layout depends on `sign`. + let mut denoms_cpu: Vec = Vec::with_capacity(n * k_scalars); + for z in &z_scalars { + for x in &x_lde { + let x_lifted = Fp3::new([*x, Fp::zero(), Fp::zero()]); + let d = match sign { + DenomSign::ZMinusX => z - &x_lifted, + DenomSign::XMinusZ => &x_lifted - z, + }; + denoms_cpu.push(d); + } + } + FieldElement::inplace_batch_inverse(&mut denoms_cpu).expect("denoms non-zero"); + + // GPU: H2D x_lde, then run the fused compute+invert. 
+ let be = backend().unwrap(); + let stream = be.next_stream(); + let x_u64: Vec = x_lde.iter().map(|x| *x.value()).collect(); + let x_dev = stream.clone_htod(&x_u64).unwrap(); + let z_u64 = ext3_to_u64s(&z_scalars); + let inv_dev = + compute_and_invert_denoms_ext3_dev(&x_dev, &z_u64, n, k_scalars, sign, &stream).unwrap(); + let gpu_u64: Vec = stream.clone_dtoh(&inv_dev).unwrap(); + stream.synchronize().unwrap(); + + let cpu_u64 = ext3_to_u64s(&denoms_cpu); + let gpu_canon = canon3(&gpu_u64); + let cpu_canon = canon3(&cpu_u64); + + for i in 0..(n * k_scalars) { + let g = &gpu_canon[i * 3..(i + 1) * 3]; + let c = &cpu_canon[i * 3..(i + 1) * 3]; + assert_eq!( + g, + c, + "mismatch at flat={i} (k={}, idx={}) n={n} k_scalars={k_scalars}", + i / n, + i % n + ); + } +} + +#[test] +fn denoms_small_both_signs() { + // Tiny shapes for fast-feedback debugging, both sign conventions. + run(8, 1, DenomSign::ZMinusX, 100); + run(8, 1, DenomSign::XMinusZ, 101); + run(16, 3, DenomSign::ZMinusX, 200); + run(64, 5, DenomSign::XMinusZ, 300); +} + +#[test] +fn denoms_r3_ood_shape() { + // R3 OOD: n = trace_size, k = num_eval_points (z - x convention). + run(1 << 14, 4, DenomSign::ZMinusX, 400); + run(1 << 16, 4, DenomSign::ZMinusX, 500); +} + +#[test] +fn denoms_r4_deep_shape() { + // R4 DEEP: n = lde_size, k = 1 + num_eval_points (x - z convention). 
+ run(1 << 18, 5, DenomSign::XMinusZ, 600); +} diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 24ff4f0c2..d0f6a51ef 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -41,6 +41,8 @@ criterion = { version = "0.4", default-features = false } env_logger = "*" test-log = { version = "0.2.11", features = ["log"] } bincode = "1" +rand = { version = "0.8.5", features = ["std"] } +rand_chacha = "0.3.1" [features] test-utils = [] diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index e797cfe3a..36756b40b 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -8,9 +8,12 @@ use core::mem::transmute_copy; use std::any::TypeId; use std::slice::{from_raw_parts, from_raw_parts_mut}; +use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::{AtomicU64, Ordering}; +use math_cuda::{CudaSlice, CudaStream}; + use crypto::fiat_shamir::is_transcript::IsStarkTranscript; use crypto::merkle_tree::merkle::MerkleTree; use crypto::merkle_tree::traits::IsMerkleTreeBackend; @@ -70,6 +73,7 @@ pub fn reset_all_gpu_call_counters() { GPU_COMP_POLY_TREE_CALLS.store(0, Ordering::Relaxed); GPU_DEEP_CALLS.store(0, Ordering::Relaxed); GPU_FRI_CALLS.store(0, Ordering::Relaxed); + GPU_BATCH_INVERT_CALLS.store(0, Ordering::Relaxed); } pub(crate) static GPU_EXTEND_HALVES_CALLS: AtomicU64 = AtomicU64::new(0); @@ -716,7 +720,7 @@ where FieldElement::::from_raw(sums_raw[c * 3 + 1]), FieldElement::::from_raw(sums_raw[c * 3 + 2]), ]); - let final_ext3 = &s * &scalar_e; + let final_ext3 = s * scalar_e; // SAFETY: TypeId-checked at the caller. E == Ext3, identical layout. 
let final_e: FieldElement = unsafe { transmute_copy::, FieldElement>(&final_ext3) }; @@ -812,7 +816,8 @@ pub(crate) fn try_barycentric_base_on_handle( n_inv: &FieldElement, g_n_inv: &FieldElement, z_pow_n: &FieldElement, - inv_denoms: &[FieldElement], + inv_denoms_host: &[FieldElement], + r3_ctx: Option<(&R3DevContext, usize)>, ) -> Option>> where F: IsField + IsSubFieldOf + 'static, @@ -833,28 +838,49 @@ where if !n.is_power_of_two() || n < gpu_bary_threshold() { return None; } - if inv_denoms.len() != n || main.lde_size != n.checked_mul(row_stride)? { + if main.lde_size != n.checked_mul(row_stride)? { + return None; + } + // Host inv_denoms length only matters on the host path. + if r3_ctx.is_none() && inv_denoms_host.len() != n { return None; } // SAFETY: F == Goldilocks per TypeId check; FieldElement is // #[repr(transparent)] over u64. let points_raw: &[u64] = unsafe { from_raw_parts(coset_points.as_ptr() as *const u64, n) }; - // SAFETY: E == Ext3 per TypeId check; FieldElement backing is - // `[FieldElement; 3]` = `[u64; 3]`. - let inv_denoms_len = n.checked_mul(3).expect("inv_denoms u64 len overflow"); - let inv_denoms_raw: &[u64] = - unsafe { from_raw_parts(inv_denoms.as_ptr() as *const u64, inv_denoms_len) }; - - let sums_raw = match math_cuda::barycentric::barycentric_base_on_device( - main, - row_stride, - points_raw, - inv_denoms_raw, - n, - ) { - Ok(v) => v, - Err(_) => return None, + + let sums_raw = match r3_ctx { + Some((ctx, inv_offset_u64)) => { + match math_cuda::barycentric::barycentric_base_on_device_with_dev_inv_denoms( + &ctx.stream, + main, + row_stride, + &ctx.coset_points, + &ctx.inv_denoms, + inv_offset_u64, + n, + ) { + Ok(v) => v, + Err(_) => return None, + } + } + None => { + // SAFETY: E == Ext3 per TypeId check; FieldElement backing is `[u64; 3]`. 
+ let inv_denoms_len = n.checked_mul(3).expect("inv_denoms u64 len overflow"); + let inv_denoms_raw: &[u64] = + unsafe { from_raw_parts(inv_denoms_host.as_ptr() as *const u64, inv_denoms_len) }; + match math_cuda::barycentric::barycentric_base_on_device( + main, + row_stride, + points_raw, + inv_denoms_raw, + n, + ) { + Ok(v) => v, + Err(_) => return None, + } + } }; GPU_BARY_CALLS.fetch_add(1, Ordering::Relaxed); @@ -873,7 +899,8 @@ pub(crate) fn try_barycentric_ext3_on_handle( n_inv: &FieldElement, g_n_inv: &FieldElement, z_pow_n: &FieldElement, - inv_denoms: &[FieldElement], + inv_denoms_host: &[FieldElement], + r3_ctx: Option<(&R3DevContext, usize)>, ) -> Option>> where F: IsField + IsSubFieldOf + 'static, @@ -894,24 +921,45 @@ where if !n.is_power_of_two() || n < gpu_bary_threshold() { return None; } - if inv_denoms.len() != n || aux.lde_size != n.checked_mul(row_stride)? { + if aux.lde_size != n.checked_mul(row_stride)? { + return None; + } + if r3_ctx.is_none() && inv_denoms_host.len() != n { return None; } let points_raw: &[u64] = unsafe { from_raw_parts(coset_points.as_ptr() as *const u64, n) }; - let inv_denoms_len = n.checked_mul(3).expect("inv_denoms u64 len overflow"); - let inv_denoms_raw: &[u64] = - unsafe { from_raw_parts(inv_denoms.as_ptr() as *const u64, inv_denoms_len) }; - - let sums_raw = match math_cuda::barycentric::barycentric_ext3_on_device( - aux, - row_stride, - points_raw, - inv_denoms_raw, - n, - ) { - Ok(v) => v, - Err(_) => return None, + + let sums_raw = match r3_ctx { + Some((ctx, inv_offset_u64)) => { + match math_cuda::barycentric::barycentric_ext3_on_device_with_dev_inv_denoms( + &ctx.stream, + aux, + row_stride, + &ctx.coset_points, + &ctx.inv_denoms, + inv_offset_u64, + n, + ) { + Ok(v) => v, + Err(_) => return None, + } + } + None => { + let inv_denoms_len = n.checked_mul(3).expect("inv_denoms u64 len overflow"); + let inv_denoms_raw: &[u64] = + unsafe { from_raw_parts(inv_denoms_host.as_ptr() as *const u64, inv_denoms_len) }; 
+ match math_cuda::barycentric::barycentric_ext3_on_device( + aux, + row_stride, + points_raw, + inv_denoms_raw, + n, + ) { + Ok(v) => v, + Err(_) => return None, + } + } }; GPU_BARY_CALLS.fetch_add(1, Ordering::Relaxed); @@ -936,6 +984,16 @@ pub fn gpu_fri_calls() -> u64 { GPU_FRI_CALLS.load(Ordering::Relaxed) } +/// Batch-invert dispatch counter (one per +/// [`try_compute_and_invert_inv_denoms_dev`] call that actually built a +/// device handle). Fires at most twice per prove per table: once for R3 +/// OOD's `num_eval_points * trace_size` denominators and once for R4 +/// DEEP's `(1 + num_eval_points) * lde_size` denominators. +pub(crate) static GPU_BATCH_INVERT_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_batch_invert_calls() -> u64 { + GPU_BATCH_INVERT_CALLS.load(Ordering::Relaxed) +} + /// Test-only: schedule the Nth upcoming FRI fold call (1 = first, 2 = /// second, ...) to return Err, exercising the snapshot-restore path in /// [`try_fri_commit_gpu`]. Pass -1 to disable. Production default is -1. @@ -945,6 +1003,15 @@ pub fn schedule_fri_fold_fault(n_calls_until_err: i64) { math_cuda::fri::FAULT_FOLDS_REMAINING_UNTIL_ERR.store(n_calls_until_err, Ordering::Relaxed); } +/// Test-only: schedule the Nth upcoming `compute_and_invert_denoms_ext3_dev` +/// call to return Err, exercising the CPU-fallback path in +/// [`try_compute_and_invert_inv_denoms_dev`]. Pass -1 to disable. +#[cfg(feature = "test-cuda-faults")] +pub fn schedule_inverse_fault(n_calls_until_err: i64) { + math_cuda::inverse::FAULT_INVERSE_REMAINING_UNTIL_ERR + .store(n_calls_until_err, Ordering::Relaxed); +} + /// R2 GPU dispatch: batched ext3 LDE over `parts_coefs` (composition-poly /// coefficient parts). 
Returns both the host LDE eval Vecs (needed for the /// R2 Merkle commit and R3 OOD path) and a device-resident `GpuLdeExt3` @@ -1084,7 +1151,8 @@ pub(crate) fn try_deep_composition_gpu( trace_ood_columns: &[Vec>], composition_poly_gammas: &[FieldElement], trace_terms_gammas: &[Vec>], - inv_denoms: &[FieldElement], + inv_denoms_host: &[FieldElement], + inv_denoms_dev: Option<(&CudaSlice, &Arc)>, num_eval_points: usize, ) -> Option>> where @@ -1126,7 +1194,15 @@ where return None; } let expected_inv_denoms = lde_size.checked_mul(1 + num_eval_points)?; - if inv_denoms.len() != expected_inv_denoms { + // The fully-resident `(Some(parts), Some(dev_inv))` arm ignores the + // host inv_denoms slice; every other arm slices into it. Validate the + // host length whenever the chosen arm will consume it, even when a + // dev inv_denoms handle is also present (a (None, Some) combination + // is reachable when R2's keep path missed but the batch-invert + // dispatch succeeded; without this guard that path would panic + // slicing an empty host buffer). + let arm_needs_host_inv = !(parts_dev.is_some() && inv_denoms_dev.is_some()); + if arm_needs_host_inv && inv_denoms_host.len() != expected_inv_denoms { return None; } @@ -1164,69 +1240,102 @@ where gammas_tr_raw.extend_from_slice(slice); } - // inv_denoms is laid out as (1 + num_eval_points) blocks of lde_size - // each. Split the H-term block and the trace blocks (concatenated). - let inv_h_raw: &[u64] = unsafe { ext3_slice_to_u64::(&inv_denoms[0..lde_size]) }; - let inv_t_raw: &[u64] = - unsafe { ext3_slice_to_u64::(&inv_denoms[lde_size..lde_size * (1 + num_eval_points)]) }; - // domain_size == lde_size here: R4 DEEP evaluates at every LDE point // (Plonky3-style direct LDE). Calling the kernel with row_stride = 1 // makes its `row = i * row_stride` index every row. let domain_size_kernel = lde_size; let row_stride_kernel = 1usize; - // Pack parts host path if no device handle. 
+ // Three dispatch paths, in priority order: + // 1. Both parts + inv_denoms on device: the fully-resident path. + // Requires the caller's stream so the new inv_denoms_dev producer + // and this kernel run on the same queue (no cross-stream race). + // 2. Parts on device, inv_denoms on host. + // 3. Both on host (fallback when R2 keep + denom-invert both missed). let parts_host_packed: Vec; - let result = if let Some(parts) = parts_dev { - math_cuda::deep::deep_composition_ext3_with_dev_parts( - main, - aux_handle, - parts, - h_ood_raw, - &trace_ood_raw, - gammas_h_raw, - &gammas_tr_raw, - inv_h_raw, - inv_t_raw, - num_parts, - num_main, - num_aux, - num_eval_points, - row_stride_kernel, - domain_size_kernel, - ) - } else { - // De-interleave each ext3 part column into 3 contiguous base-field - // slabs of length `lde_size` (the math-cuda kernel reads the parts - // buffer with layout `h_lde[(p*3 + k) * lde_stride + r]`). - let mut packed = vec![0u64; num_parts * 3 * lde_size]; - for (p, col) in parts_host.iter().enumerate() { - let slice = unsafe { ext3_slice_to_u64::(col) }; - for (r, chunk) in slice.chunks_exact(3).enumerate() { - packed[(p * 3) * lde_size + r] = chunk[0]; - packed[(p * 3 + 1) * lde_size + r] = chunk[1]; - packed[(p * 3 + 2) * lde_size + r] = chunk[2]; + let result = match (parts_dev, inv_denoms_dev) { + (Some(parts), Some((inv_dev, stream))) => { + math_cuda::deep::deep_composition_ext3_with_dev_parts_and_inv_denoms( + stream, + main, + aux_handle, + parts, + inv_dev, + h_ood_raw, + &trace_ood_raw, + gammas_h_raw, + &gammas_tr_raw, + num_parts, + num_main, + num_aux, + num_eval_points, + row_stride_kernel, + domain_size_kernel, + ) + } + (Some(parts), None) => { + let inv_h_raw: &[u64] = + unsafe { ext3_slice_to_u64::(&inv_denoms_host[0..lde_size]) }; + let inv_t_raw: &[u64] = unsafe { + ext3_slice_to_u64::(&inv_denoms_host[lde_size..lde_size * (1 + num_eval_points)]) + }; + math_cuda::deep::deep_composition_ext3_with_dev_parts( + main, + 
aux_handle, + parts, + h_ood_raw, + &trace_ood_raw, + gammas_h_raw, + &gammas_tr_raw, + inv_h_raw, + inv_t_raw, + num_parts, + num_main, + num_aux, + num_eval_points, + row_stride_kernel, + domain_size_kernel, + ) + } + (None, _) => { + // De-interleave each ext3 part column into 3 contiguous base-field + // slabs of length `lde_size` (the math-cuda kernel reads the parts + // buffer with layout `h_lde[(p*3 + k) * lde_stride + r]`). + let mut packed = vec![0u64; num_parts * 3 * lde_size]; + for (p, col) in parts_host.iter().enumerate() { + let slice = unsafe { ext3_slice_to_u64::(col) }; + for (r, chunk) in slice.chunks_exact(3).enumerate() { + packed[(p * 3) * lde_size + r] = chunk[0]; + packed[(p * 3 + 1) * lde_size + r] = chunk[1]; + packed[(p * 3 + 2) * lde_size + r] = chunk[2]; + } } + parts_host_packed = packed; + // Host inv_denoms required when going through this path; we + // validated the slice length above. + let inv_h_raw: &[u64] = + unsafe { ext3_slice_to_u64::(&inv_denoms_host[0..lde_size]) }; + let inv_t_raw: &[u64] = unsafe { + ext3_slice_to_u64::(&inv_denoms_host[lde_size..lde_size * (1 + num_eval_points)]) + }; + math_cuda::deep::deep_composition_ext3( + main, + aux_handle, + &parts_host_packed, + h_ood_raw, + &trace_ood_raw, + gammas_h_raw, + &gammas_tr_raw, + inv_h_raw, + inv_t_raw, + num_parts, + num_main, + num_aux, + num_eval_points, + row_stride_kernel, + domain_size_kernel, + ) } - parts_host_packed = packed; - math_cuda::deep::deep_composition_ext3( - main, - aux_handle, - &parts_host_packed, - h_ood_raw, - &trace_ood_raw, - gammas_h_raw, - &gammas_tr_raw, - inv_h_raw, - inv_t_raw, - num_parts, - num_main, - num_aux, - num_eval_points, - row_stride_kernel, - domain_size_kernel, - ) }; let deep_raw = match result { @@ -1238,6 +1347,166 @@ where Some(u64_to_ext3_vec::(&deep_raw)) } +/// Build `inv_denoms[k*n + i] = 1 / (lift(coset_base[i]) - z_scalars[k])` +/// entirely on device. 
Used by both R3 OOD (n = trace_size, k_scalars = +/// num_eval_points) and R4 DEEP (n = lde_size, k_scalars = 1 + +/// num_eval_points). Returns a device handle the caller can slice and +/// thread into downstream dispatchers without ever D2H'ing the inverted +/// values; on type / threshold / cudarc failure returns `None` so the +/// caller can fall back to CPU `inplace_batch_inverse`. +/// +/// The threshold check uses `gpu_lde_threshold()` against `n * k_scalars`, +/// matching the rest of the dispatch layer. +pub(crate) fn try_compute_and_invert_inv_denoms_dev( + coset_base: &[FieldElement], + z_scalars: &[FieldElement], + sign: math_cuda::inverse::DenomSign, + stream: &Arc, +) -> Option> +where + F: IsField + 'static, + E: IsField + 'static, +{ + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + let n = coset_base.len(); + let k_scalars = z_scalars.len(); + if n == 0 || k_scalars == 0 { + return None; + } + let total = n.checked_mul(k_scalars)?; + if total < gpu_lde_threshold() { + return None; + } + + // SAFETY: F == Goldilocks per TypeId check; FieldElement is + // #[repr(transparent)] over u64. + let coset_u64: &[u64] = unsafe { from_raw_parts(coset_base.as_ptr() as *const u64, n) }; + let coset_dev = match stream.clone_htod(coset_u64) { + Ok(s) => s, + Err(_) => return None, + }; + + // SAFETY: E == Ext3 per TypeId check. + let z_u64: &[u64] = unsafe { ext3_slice_to_u64::(z_scalars) }; + + let result = math_cuda::inverse::compute_and_invert_denoms_ext3_dev( + &coset_dev, z_u64, n, k_scalars, sign, stream, + ); + match result { + Ok(handle) => { + GPU_BATCH_INVERT_CALLS.fetch_add(1, Ordering::Relaxed); + Some(handle) + } + Err(_) => None, + } +} + +/// Convenience wrapper for prover callers that don't yet own a stream: +/// acquires the math-cuda backend, allocates a fresh stream, and produces +/// a device-resident `inv_denoms` buffer plus the stream that owns it. 
+/// The caller passes the tuple through to the downstream dispatch +/// functions (`try_barycentric_*_on_handle`, `try_deep_composition_gpu`) +/// so every kernel touching the buffer runs on the same stream (no +/// cross-stream race). +/// +/// Returns `None` on type / threshold mismatch, backend init failure, or +/// any cudarc error; the caller falls back to its CPU +/// `inplace_batch_inverse` loop. +pub(crate) fn try_inv_denoms_dev_with_stream( + coset_base: &[FieldElement], + z_scalars: &[FieldElement], + sign: math_cuda::inverse::DenomSign, +) -> Option<(CudaSlice, Arc)> +where + F: IsField + 'static, + E: IsField + 'static, +{ + let be = math_cuda::device::backend().ok()?; + let stream = be.next_stream(); + let handle = + try_compute_and_invert_inv_denoms_dev::(coset_base, z_scalars, sign, &stream)?; + Some((handle, stream)) +} + +/// R3 OOD device-side context: bundles the inverted denominators, the +/// coset_points upload (used by every barycentric kernel for this batch), +/// and the stream so producer + consumers serialize naturally. Hoisting +/// `coset_points` here means the barycentric kernels read the same +/// device buffer across `num_eval_points * {main, aux}` calls instead +/// of re-uploading `dc.points` each iteration. +#[derive(Debug)] +pub(crate) struct R3DevContext { + pub inv_denoms: CudaSlice, + pub coset_points: CudaSlice, + pub stream: Arc, +} + +/// Build an [`R3DevContext`] in one stream: acquire backend, allocate +/// stream, H2D coset_points once, then run `compute_and_invert_denoms` +/// against that same handle so the coset H2D isn't repeated by any +/// downstream barycentric kernel. +/// +/// Returns `None` on type / threshold mismatch, backend init failure, or +/// any cudarc error. 
+pub(crate) fn try_prep_r3_dev_context( + coset_base: &[FieldElement], + z_scalars: &[FieldElement], +) -> Option +where + F: IsField + 'static, + E: IsField + 'static, +{ + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + let n = coset_base.len(); + let k_scalars = z_scalars.len(); + if n == 0 || k_scalars == 0 { + return None; + } + let total = n.checked_mul(k_scalars)?; + if total < gpu_lde_threshold() { + return None; + } + + let be = math_cuda::device::backend().ok()?; + let stream = be.next_stream(); + + // SAFETY: F == Goldilocks per TypeId check; FieldElement is + // #[repr(transparent)] over u64. + let coset_u64: &[u64] = unsafe { from_raw_parts(coset_base.as_ptr() as *const u64, n) }; + let coset_points = stream.clone_htod(coset_u64).ok()?; + + // SAFETY: E == Ext3 per TypeId check. + let z_u64: &[u64] = unsafe { ext3_slice_to_u64::(z_scalars) }; + + let inv_denoms = match math_cuda::inverse::compute_and_invert_denoms_ext3_dev( + &coset_points, + z_u64, + n, + k_scalars, + math_cuda::inverse::DenomSign::ZMinusX, + &stream, + ) { + Ok(h) => h, + Err(_) => return None, + }; + GPU_BATCH_INVERT_CALLS.fetch_add(1, Ordering::Relaxed); + Some(R3DevContext { + inv_denoms, + coset_points, + stream, + }) +} + /// R4 FRI dispatch: drive the full FRI commit phase device-side. Mirrors /// [`crate::fri::commit_phase_from_evaluations`]: per-layer transcript /// ping-pong (sample zeta, fold, build Merkle tree, append root). 
diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index 3ae8415c1..e9f6a1cda 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -22,6 +22,7 @@ pub mod lookup; pub(crate) mod par; pub mod proof; pub mod prover; +pub mod r4_denoms; #[cfg(feature = "disk-spill")] pub mod storage_mode; pub mod table; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 601195ffb..46261103e 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -1349,41 +1349,63 @@ pub trait IsStarkProver< // Number of main and aux columns in the LDE trace let num_main_cols = lde_trace.num_main_cols(); let num_aux_cols = lde_trace.num_aux_cols(); - - // Precompute all inverse denominators at ALL LDE points via batch inversion. let lde_size = domain.lde_roots_of_unity_coset.len(); - let num_denoms = lde_size * (1 + num_eval_points); - let mut denoms: Vec> = Vec::with_capacity(num_denoms); - // H-term denominators: x_i - z^K (all 2N LDE points) - for i in 0..lde_size { - let x_i = &domain.lde_roots_of_unity_coset[i]; - denoms.push(x_i - &z_power); - } + // OOD evaluations + let h_ood = &round_3_result.composition_poly_parts_ood_evaluation; + let trace_ood_columns = round_3_result.trace_ood_evaluations.columns(); + let num_total_cols = num_main_cols + num_aux_cols; - // Trace-term denominators: x_i - z_shifted[k] (all 2N LDE points) - for z_k in z_shifted.iter().take(num_eval_points) { - for i in 0..lde_size { - let x_i = &domain.lde_roots_of_unity_coset[i]; - denoms.push(x_i - z_k); + // Fully device-resident GPU fast path: build inv_denoms on device + // ([z^K, z_shifted[0..]] over the full LDE coset), then run R4 + // DEEP composition reading the same device buffer. Skips the + // CPU `inplace_batch_inverse` on the happy path; on any GPU + // failure we fall through and compute denoms on CPU below. 
+ #[cfg(feature = "cuda")] + { + let z_scalars: Vec> = core::iter::once(z_power.clone()) + .chain(z_shifted.iter().cloned()) + .collect(); + if let Some((inv_dev, stream)) = + crate::gpu_lde::try_inv_denoms_dev_with_stream::( + &domain.lde_roots_of_unity_coset, + &z_scalars, + math_cuda::inverse::DenomSign::XMinusZ, + ) + && let Some(deep_evals) = + crate::gpu_lde::try_deep_composition_gpu::( + lde_trace, + round_2_result.gpu_composition_parts.as_ref(), + &round_2_result.lde_composition_poly_evaluations, + h_ood, + &trace_ood_columns, + composition_poly_gammas, + trace_terms_gammas, + &[], + Some((&inv_dev, &stream)), + num_eval_points, + ) + { + return deep_evals; } } - FieldElement::inplace_batch_inverse(&mut denoms) - .expect("Denominators should be non-zero: coset points are base field, poles are extension field"); + // CPU denoms + batch inverse for the fallback paths below. + // Single-source helper shared with the GPU parity test so any + // sign/ordering/layout drift breaks the test instead of silently + // diverging CUDA vs non-CUDA proofs. + let denoms = crate::r4_denoms::build_r4_inv_denoms_cpu::( + &domain.lde_roots_of_unity_coset, + &z_power, + &z_shifted, + ) + .expect("R4 inv denoms: coset points are base field, poles are extension field"); let inv_h = &denoms[0..lde_size]; - // OOD evaluations - let h_ood = &round_3_result.composition_poly_parts_ood_evaluation; - let trace_ood_columns = round_3_result.trace_ood_evaluations.columns(); - let num_total_cols = num_main_cols + num_aux_cols; - - // GPU fast path: device-resident DEEP composition. Reuses the R1 - // main/aux LDE handles on `lde_trace` and (when the R2 fused path - // ran) the parts handle on `round_2_result.gpu_composition_parts`. - // Falls back to the CPU rayon loop below on any precondition miss - // or kernel failure. + // GPU mixed path: dev parts (when R2 keep handle exists) + host + // inv_denoms. 
Used when the dev-inv-denoms path above didn't fire + // (e.g., cudarc error in compute_denoms / scan). #[cfg(feature = "cuda")] { if let Some(deep_evals) = @@ -1396,6 +1418,7 @@ pub trait IsStarkProver< composition_poly_gammas, trace_terms_gammas, &denoms, + None, num_eval_points, ) { diff --git a/crypto/stark/src/r4_denoms.rs b/crypto/stark/src/r4_denoms.rs new file mode 100644 index 000000000..77076ecfe --- /dev/null +++ b/crypto/stark/src/r4_denoms.rs @@ -0,0 +1,45 @@ +//! Single-source builder for R4 DEEP inverse denominators on CPU. +//! +//! Called by both the prover's CPU fallback in +//! `compute_deep_composition_poly_evaluations` and by the GPU parity test +//! that pins this construction against the device pipeline +//! (`compute_and_invert_denoms_ext3_dev`). Keeping it in one place means a +//! sign/ordering/layout drift cannot diverge CUDA and non-CUDA builds +//! silently. +//! +//! Convention (mirrors `compute_and_invert_denoms_ext3_dev` with +//! `DenomSign::XMinusZ`): +//! - `z_scalars = [z_power, z_shifted[0..]]`, length `1 + z_shifted.len()` +//! - `denoms[k * lde_size + i] = x_i - z_scalars[k]` (then inverted) + +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; + +/// Build `1 / (x_i - z_k)` for k in [0..=z_shifted.len()] and i in [0..n) +/// where `z = [z_power, z_shifted[0..]]`. Output is flat, k-major: +/// `out[k * coset.len() + i] = (x_i - z_k)^{-1}`. +/// +/// Returns `Err` only if `inplace_batch_inverse` hits a zero element, +/// which should not occur in honest proving (the probability that a Fiat-Shamir +/// `z` lands on the LDE coset is negligible), but the contract follows lambdaworks' API.
+pub fn build_r4_inv_denoms_cpu( + coset: &[FieldElement], + z_power: &FieldElement, + z_shifted: &[FieldElement], +) -> Result>, &'static str> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + let n = coset.len(); + let num_denoms = n * (1 + z_shifted.len()); + let mut denoms: Vec> = Vec::with_capacity(num_denoms); + for z_k in core::iter::once(z_power).chain(z_shifted.iter()) { + for x_i in coset { + denoms.push(x_i - z_k); + } + } + FieldElement::inplace_batch_inverse(&mut denoms) + .map_err(|_| "R4 inv denoms: zero denominator (z hit the LDE coset)")?; + Ok(denoms) +} diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index f4469447d..f63aa72de 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -460,7 +460,23 @@ where let mut table_data = Vec::with_capacity(evaluation_points.len() * table_width); - for eval_point in &evaluation_points { + // GPU fast path for R3 OOD: bundle the inverted inv_denoms (all + // eval points in one buffer) and the trace-size coset_points upload + // into a single device context. The barycentric kernels below read + // both via offset, with no per-eval-point or per-{main,aux} H2D. + #[cfg(feature = "cuda")] + let r3_ctx: Option = + crate::gpu_lde::try_prep_r3_dev_context::(&dc.points, &evaluation_points); + #[allow(unused_variables)] + #[cfg(not(feature = "cuda"))] + let r3_ctx: Option<()> = None; + + #[cfg_attr(not(feature = "cuda"), allow(clippy::unused_enumerate_index))] + for (eval_point_idx, eval_point) in evaluation_points.iter().enumerate() { + // Silence unused warning under non-cuda where eval_point_idx is + // only read inside the cuda-only block below. 
+ #[cfg(not(feature = "cuda"))] + let _ = eval_point_idx; // z_pow_n for this evaluation point let z_pow_n = eval_point.pow(n); @@ -468,11 +484,20 @@ where let vanishing = z_pow_n.sub_subfield(&dc.offset_pow_n); let vanishing_factor = &n_inv_g_n_inv * &vanishing; - // Precompute inv_denoms = 1/(eval_point - coset_point_i), shared across all columns. - // Stays on CPU: the batch-invert cost at this scale (n * num_eval_points) is already - // rayon-parallelised across tables, and a GPU port regressed wall time in a - // 2x15-trial A/B due to stream contention from many concurrent launches. - let inv_denoms = barycentric_inv_denoms(eval_point, &dc.points); + // CPU inv_denoms = 1/(eval_point - coset_point_i). Materialised + // eagerly only when the GPU dispatcher will need to H2D it (no + // device-side inv_denoms buffer available). On the all-GPU happy + // path it stays None and the `barycentric_inv_denoms` call is + // skipped entirely (the GPU buffer covers every eval point). + #[cfg(feature = "cuda")] + let mut inv_denoms: Option>> = if r3_ctx.is_some() { + None + } else { + Some(barycentric_inv_denoms(eval_point, &dc.points)) + }; + #[cfg(not(feature = "cuda"))] + let mut inv_denoms: Option>> = + Some(barycentric_inv_denoms(eval_point, &dc.points)); // col_scale[i] = point[i] * inv_denom[i], shared across ALL CPU column // loops below. Computed lazily on first CPU-fallback use so the all-GPU @@ -484,6 +509,10 @@ where // for this table (handle absent), the size is below threshold, types // don't match, or the math-cuda call errored. Caller falls through // to the existing rayon CPU loop. + // Per-eval-point block offset into the GPU inv_denoms buffer: + // block k starts at u64 index k * 3 * n. 
+ #[cfg(feature = "cuda")] + let r3_arg = r3_ctx.as_ref().map(|ctx| (ctx, eval_point_idx * 3 * n)); #[cfg(feature = "cuda")] let main_gpu = crate::gpu_lde::try_barycentric_base_on_handle::( lde_trace, @@ -493,7 +522,8 @@ where &dc.size_inv, &dc.offset_pow_n_inv, &z_pow_n, - &inv_denoms, + inv_denoms.as_deref().unwrap_or(&[]), + r3_arg, ); #[cfg(not(feature = "cuda"))] let main_gpu: Option>> = None; @@ -501,10 +531,12 @@ where let main_evals: Vec> = if let Some(v) = main_gpu { v } else { + let inv_denoms_v = + inv_denoms.get_or_insert_with(|| barycentric_inv_denoms(eval_point, &dc.points)); let col_scale = col_scale.get_or_insert_with(|| { dc.points .iter() - .zip(inv_denoms.iter()) + .zip(inv_denoms_v.iter()) .map(|(point, inv_d)| point * inv_d) .collect() }); @@ -532,6 +564,8 @@ where // GPU fast path for aux columns reading the de-interleaved ext3 LDE handle. #[cfg(feature = "cuda")] + let r3_arg_aux = r3_ctx.as_ref().map(|ctx| (ctx, eval_point_idx * 3 * n)); + #[cfg(feature = "cuda")] let aux_gpu = crate::gpu_lde::try_barycentric_ext3_on_handle::( lde_trace, bf, @@ -540,7 +574,8 @@ where &dc.size_inv, &dc.offset_pow_n_inv, &z_pow_n, - &inv_denoms, + inv_denoms.as_deref().unwrap_or(&[]), + r3_arg_aux, ); #[cfg(not(feature = "cuda"))] let aux_gpu: Option>> = None; @@ -548,10 +583,12 @@ where let aux_evals: Vec> = if let Some(v) = aux_gpu { v } else { + let inv_denoms_v = + inv_denoms.get_or_insert_with(|| barycentric_inv_denoms(eval_point, &dc.points)); let col_scale = col_scale.get_or_insert_with(|| { dc.points .iter() - .zip(inv_denoms.iter()) + .zip(inv_denoms_v.iter()) .map(|(point, inv_d)| point * inv_d) .collect() }); diff --git a/crypto/stark/tests/r4_denoms_parity.rs b/crypto/stark/tests/r4_denoms_parity.rs new file mode 100644 index 000000000..ad8284103 --- /dev/null +++ b/crypto/stark/tests/r4_denoms_parity.rs @@ -0,0 +1,114 @@ +//! R4 DEEP inverse-denominator parity: GPU `compute_and_invert_denoms_ext3_dev` +//! 
(with `DenomSign::XMinusZ`, the convention used by the prover's R4 DEEP +//! fast path) must match the CPU helper `build_r4_inv_denoms_cpu` that the +//! prover's CPU fallback also calls into. +//! +//! Pins the three-copy fragility flagged in PR review: kernel construction, +//! CPU fallback in prover.rs, and any test references must all be the same. +//! Drift on either the helper or the kernel fails this (GPU-only, `#[ignore]`d) test. +//! +//! Requires the `cuda` feature. + +#![cfg(feature = "cuda")] + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math_cuda::device::backend; +use math_cuda::inverse::{DenomSign, compute_and_invert_denoms_ext3_dev}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use stark::r4_denoms::build_r4_inv_denoms_cpu; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} + +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn canon3(a: &[u64]) -> Vec { + a.iter().map(GoldilocksField::canonical).collect() +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn run_parity(lde_size: usize, num_eval_points: usize, seed: u64) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let coset: Vec = (0..lde_size).map(|_| rand_fp(&mut rng)).collect(); + let z_power = rand_fp3(&mut rng); + let z_shifted: Vec = (0..num_eval_points).map(|_| rand_fp3(&mut rng)).collect(); + + // CPU side via the shared helper used by the prover's fallback.
+ let cpu = build_r4_inv_denoms_cpu::( + &coset, &z_power, &z_shifted, + ) + .expect("non-zero denoms"); + let cpu_u64 = canon3(&ext3_to_u64s(&cpu)); + + // GPU side via the device pipeline that the prover's fast path calls. + let be = backend().unwrap(); + let stream = be.next_stream(); + let coset_u64: Vec = coset.iter().map(|x| *x.value()).collect(); + let coset_dev = stream.clone_htod(&coset_u64).unwrap(); + let mut z_scalars: Vec = Vec::with_capacity(1 + num_eval_points); + z_scalars.push(z_power); + z_scalars.extend_from_slice(&z_shifted); + let z_u64 = ext3_to_u64s(&z_scalars); + let gpu_dev = compute_and_invert_denoms_ext3_dev( + &coset_dev, + &z_u64, + lde_size, + 1 + num_eval_points, + DenomSign::XMinusZ, + &stream, + ) + .unwrap(); + let gpu_u64 = canon3(&stream.clone_dtoh(&gpu_dev).unwrap()); + stream.synchronize().unwrap(); + + assert_eq!( + cpu_u64.len(), + gpu_u64.len(), + "length mismatch lde_size={lde_size} num_eval_points={num_eval_points}" + ); + for i in 0..(lde_size * (1 + num_eval_points)) { + let c = &cpu_u64[i * 3..(i + 1) * 3]; + let g = &gpu_u64[i * 3..(i + 1) * 3]; + assert_eq!( + c, + g, + "mismatch at flat={i} (k={}, idx={}) lde_size={lde_size} num_eval_points={num_eval_points}", + i / lde_size, + i % lde_size, + ); + } +} + +#[test] +#[ignore = "requires GPU; run with --ignored --nocapture"] +fn r4_denoms_parity_small() { + run_parity(1 << 14, 2, 1); + run_parity(1 << 14, 4, 2); +} + +#[test] +#[ignore = "requires GPU; run with --ignored --nocapture"] +fn r4_denoms_parity_prover_shape() { + // fib_iterative_1M / 4M LDE sizes with the common eval-point counts. 
+ run_parity(1 << 18, 2, 100); + run_parity(1 << 20, 2, 101); +} diff --git a/prover/tests/cuda_fallback_tests.rs b/prover/tests/cuda_fallback_tests.rs index 0fc5ce172..00078d09f 100644 --- a/prover/tests/cuda_fallback_tests.rs +++ b/prover/tests/cuda_fallback_tests.rs @@ -14,7 +14,7 @@ use lambda_vm_prover::test_utils::asm_elf_bytes; use lambda_vm_prover::{prove, verify}; -use stark::gpu_lde::{gpu_fri_calls, reset_all_gpu_call_counters}; +use stark::gpu_lde::{gpu_batch_invert_calls, gpu_fri_calls, reset_all_gpu_call_counters}; /// FRI commit-phase CPU fallback: when the GPU dispatch errors after the /// first transcript mutation, `try_fri_commit_gpu` must restore the @@ -60,3 +60,43 @@ fn gpu_fri_fault_falls_back_to_cpu() { // Reset injection state for any subsequent tests in the same process. stark::gpu_lde::schedule_fri_fold_fault(-1); } + +/// Batch-invert CPU fallback: when `compute_and_invert_denoms_ext3_dev` +/// errors, `try_compute_and_invert_inv_denoms_dev` must return None so the +/// caller (R3 OOD in `trace.rs` or R4 DEEP in `prover.rs`) builds inv_denoms +/// on CPU and the remaining GPU path keeps running. +/// +/// The injection fires the Nth time the math-cuda entry point is reached, +/// across all tables. We assert that a single fault drops `gpu_batch_invert_calls` +/// by exactly one (one table fell back, the rest succeeded) and that the +/// recovered proof still verifies. 
+#[test] +#[ignore = "requires GPU + test-cuda-faults; run with --ignored --nocapture"] +fn gpu_batch_invert_fault_falls_back_to_cpu() { + let elf = asm_elf_bytes("fib_iterative_1M"); + reset_all_gpu_call_counters(); + let _ = prove(&elf).expect("warm-up"); + let clean = gpu_batch_invert_calls(); + assert!( + clean > 0, + "GPU batch-invert never ran, cannot test fallback" + ); + + for n in 1..=3i64 { + stark::gpu_lde::schedule_inverse_fault(n); + reset_all_gpu_call_counters(); + + let recovered = prove(&elf).expect("prove after fault"); + assert_eq!( + gpu_batch_invert_calls(), + clean - 1, + "expected exactly one GPU batch-invert fallback (fault #{n})" + ); + assert!( + verify(&recovered, &elf).expect("verify recovered"), + "post-fallback proof failed verification (batch-invert fault #{n})" + ); + } + + stark::gpu_lde::schedule_inverse_fault(-1); +} diff --git a/prover/tests/cuda_path_integration.rs b/prover/tests/cuda_path_integration.rs index 3653dd9a5..0f7c1f3c7 100644 --- a/prover/tests/cuda_path_integration.rs +++ b/prover/tests/cuda_path_integration.rs @@ -11,8 +11,8 @@ use lambda_vm_prover::test_utils::asm_elf_bytes; use lambda_vm_prover::{prove, verify}; use stark::gpu_lde::{ - gpu_bary_calls, gpu_comp_poly_tree_calls, gpu_deep_calls, gpu_fri_calls, gpu_lde_calls, - gpu_parts_lde_calls, reset_all_gpu_call_counters, + gpu_bary_calls, gpu_batch_invert_calls, gpu_comp_poly_tree_calls, gpu_deep_calls, + gpu_fri_calls, gpu_lde_calls, gpu_parts_lde_calls, reset_all_gpu_call_counters, }; #[test] @@ -53,6 +53,14 @@ fn gpu_path_fires_end_to_end() { // FRI commit fires once per table (commit_phase_from_evaluations). assert!(gpu_fri_calls() > 0, "R4 GPU FRI commit did not fire"); + // GPU batch-invert dispatch fires for the R3 OOD and R4 DEEP + // inv_denoms pipelines. This `> 0` check only trips when every dispatch + // silently fell back to host inv_denoms; a single fallback can go unnoticed.
+ assert!( + gpu_batch_invert_calls() > 0, + "GPU batch-invert dispatch did not fire on R3 + R4" + ); + // Counters only prove the dispatches ran; this checks the GPU proof // actually satisfies the verifier. let ok = verify(&proof, &elf).expect("verify");