Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
d1a0abf
add first cuda files
ColoCarletti May 6, 2026
79634ff
fmt
ColoCarletti May 6, 2026
ac6fbb5
fix clippy
ColoCarletti May 6, 2026
2ceb3b0
gpu 2nd part
ColoCarletti May 6, 2026
affceb1
feat(cuda): Round 1 GPU LDE+commit dispatch + device-resident handles
ColoCarletti May 6, 2026
01172f2
merge main
ColoCarletti May 19, 2026
c4627e1
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
ColoCarletti May 19, 2026
01aa5e4
comments fix
ColoCarletti May 20, 2026
cfc5c19
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
MauroToscano May 21, 2026
ea5696f
Update crypto/stark/src/gpu_lde.rs
ColoCarletti May 21, 2026
a8cf265
Update crypto/stark/src/gpu_lde.rs
ColoCarletti May 21, 2026
fb8d31f
Update crypto/stark/src/gpu_lde.rs
ColoCarletti May 21, 2026
a79f2b5
Update crypto/stark/src/gpu_lde.rs
ColoCarletti May 21, 2026
761a2c0
Update crypto/stark/src/gpu_lde.rs
ColoCarletti May 21, 2026
e066e9d
address reviews
ColoCarletti May 21, 2026
7d3d0f0
fix review comments
ColoCarletti May 22, 2026
cf80771
Merge remote-tracking branch 'origin/main' into feat/cuda-pr2-r1-gpu-…
ColoCarletti May 22, 2026
71aba0d
address doc comment suggestions
ColoCarletti May 22, 2026
83d91b8
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
ColoCarletti May 22, 2026
34cae4b
fix
ColoCarletti May 22, 2026
f076bf4
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
gabrielbosio May 27, 2026
a2cde0f
Pass replay transcript to bus-balance call in verify_vm_minimal
gabrielbosio May 27, 2026
46c305b
Update crypto/math-cuda/src/device.rs
ColoCarletti May 28, 2026
aca3dca
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
ColoCarletti May 28, 2026
63d7c00
Update crypto/math-cuda/src/device.rs
ColoCarletti May 29, 2026
eb16c02
Update crypto/math-cuda/src/device.rs
ColoCarletti May 29, 2026
66925b1
Update crypto/math-cuda/src/device.rs
ColoCarletti May 29, 2026
4e6daf3
Update crypto/math-cuda/src/lde.rs
ColoCarletti May 29, 2026
4cd27d9
Update crypto/math-cuda/src/lde.rs
ColoCarletti May 29, 2026
5fe390f
Update crypto/math-cuda/src/lde.rs
ColoCarletti May 29, 2026
5819930
Update crypto/math-cuda/src/lde.rs
ColoCarletti May 29, 2026
33f7c36
Update crypto/math-cuda/src/lde.rs
ColoCarletti May 29, 2026
49d3607
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
ColoCarletti May 29, 2026
99cd59c
add pr3 code
ColoCarletti Jun 1, 2026
c52521e
Merge branch 'main' into feat/cuda-pr2-r1-gpu-commits
ColoCarletti Jun 1, 2026
828ee16
fix comments
ColoCarletti Jun 1, 2026
19a36a0
Merge remote-tracking branch 'origin/feat/cuda-pr2-r1-gpu-commits' in…
ColoCarletti Jun 1, 2026
80e1ecb
fix sync stream after D2H in merke.rs
ColoCarletti Jun 1, 2026
3ead022
Merge branch 'main' into feat/cuda-pr3
ColoCarletti Jun 1, 2026
04dd872
fix comments
ColoCarletti Jun 1, 2026
8a67e33
address review feedback
ColoCarletti Jun 1, 2026
1f9394d
Update crypto/math-cuda/src/barycentric.rs
ColoCarletti Jun 1, 2026
b07999c
Update crypto/math-cuda/src/barycentric.rs
ColoCarletti Jun 1, 2026
c575017
fix imports
ColoCarletti Jun 1, 2026
0ffc661
Merge branch 'feat/cuda-pr3' of github.com:yetanotherco/lambda_vm int…
ColoCarletti Jun 1, 2026
0777f1e
Merge branch 'main' into feat/cuda-pr3
ColoCarletti Jun 3, 2026
2c7b0de
cuda integration tests
ColoCarletti Jun 3, 2026
2f1fe2d
address review feedback
ColoCarletti Jun 3, 2026
f254eae
batch invert kernels and parity test
ColoCarletti Jun 3, 2026
84cc04b
DEEP composition kernel
ColoCarletti Jun 3, 2026
0ba7745
fri
ColoCarletti Jun 3, 2026
7046a40
gpu lde
ColoCarletti Jun 3, 2026
065c8f9
gpu_lde
ColoCarletti Jun 3, 2026
7d2810f
fri
ColoCarletti Jun 3, 2026
cc840cd
add tests
ColoCarletti Jun 3, 2026
fac3974
fix
ColoCarletti Jun 3, 2026
bc61a00
Merge branch 'main' into feat/cuda-pr4
ColoCarletti Jun 3, 2026
3c52fdf
fix comments
ColoCarletti Jun 5, 2026
59437f3
add integration tests
ColoCarletti Jun 5, 2026
c499ee0
fix comments
ColoCarletti Jun 8, 2026
025813a
refactor test
ColoCarletti Jun 8, 2026
f41bb7b
rm dead code, refactor
ColoCarletti Jun 8, 2026
6399cf2
fix
ColoCarletti Jun 8, 2026
b8d97d5
rm doc
ColoCarletti Jun 8, 2026
6f3262d
gpu batch inverse
ColoCarletti Jun 9, 2026
b422d71
fix
ColoCarletti Jun 9, 2026
95b8025
Merge branch 'feat/cuda-pr4' of github.com:yetanotherco/lambda_vm int…
ColoCarletti Jun 9, 2026
b706e48
fallback test
ColoCarletti Jun 9, 2026
5eae98a
Merge branch 'main' into feat/cuda-pr4
ColoCarletti Jun 10, 2026
578cb29
fix_comments
ColoCarletti Jun 10, 2026
50d2541
Merge remote-tracking branch 'origin/feat/cuda-pr4' into feat/cuda-pr…
ColoCarletti Jun 10, 2026
84ae125
cleanup
ColoCarletti Jun 10, 2026
ca4efc7
Merge remote-tracking branch 'origin/main' into feat/cuda-pr5-batch-i…
ColoCarletti Jun 10, 2026
adbcfe2
fmt
ColoCarletti Jun 10, 2026
7386d0a
Merge branch 'main' into feat/cuda-pr5-batch-invert
MauroToscano Jun 11, 2026
a29b013
Merge branch 'main' into feat/cuda-pr5-batch-invert
ColoCarletti Jun 12, 2026
1a1de35
address comments
ColoCarletti Jun 12, 2026
d73219e
harden inv_denoms guard, fix scan kernel race
ColoCarletti Jun 12, 2026
b0c60e1
fix debug assert
ColoCarletti Jun 12, 2026
577c6b2
Merge branch 'main' into feat/cuda-pr5-batch-invert
diegokingston Jun 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crypto/math-cuda/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,5 @@ fn main() {
compile_ptx("barycentric.cu", "barycentric.ptx", have_nvcc);
compile_ptx("deep.cu", "deep.ptx", have_nvcc);
compile_ptx("fri.cu", "fri.ptx", have_nvcc);
compile_ptx("inverse.cu", "inverse.ptx", have_nvcc);
}
313 changes: 313 additions & 0 deletions crypto/math-cuda/kernels/inverse.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// Parallel Montgomery batch inverse over ext3 elements.
//
// Algorithm: given a[0..N-1] all non-zero, compute a^{-1}[0..N-1] using
// prefix[i] = a[0] * a[1] * ... * a[i] (inclusive forward scan)
// suffix[i] = a[i] * a[i+1] * ... * a[N-1] (inclusive backward scan)
// total = prefix[N-1] = suffix[0]
// inv_total = 1 / total (one Fermat inversion on host)
// a^{-1}[i] = prefix[i-1] * inv_total * suffix[i+1] (boundaries use identity)
//
// Each scan is a multi-block 3-phase Hillis-Steele scan in shared memory:
// Phase 1: each block does an inclusive scan over its 256 elements and
// writes its block sum to a per-block totals array.
// Phase 2: recursively scan the block totals (host re-launches this same
// kernel set; recursion depth = ceil(log_256(N))).
// Phase 3: each block reads its offset (the inclusive prefix of all
// preceding block sums) and multiplies it into every element.
//
// Forward and backward kernels are mirrors of each other.
//
// Buffer layouts: all ext3 buffers are interleaved [a0,b0,c0, a1,b1,c1, ...]
// with one u64 per coordinate. `BLOCK_SIZE = 256` ext3 elements per block
// uses 6 KB of shared memory, well under the per-SM limit on Ada/Blackwell.

#include "goldilocks.cuh"
#include "ext3.cuh"

#define BLOCK_SIZE 256

// ---------------------------------------------------------------------------
// 1. compute_denoms_ext3
//
// If `subtract_x = 0` (R3 OOD convention): denoms[k * n + i] = z[k] - x[i].
// Matches CPU `barycentric_inv_denoms(z, points)` = 1/(z - points[i]).
// If `subtract_x = 1` (R4 DEEP convention): denoms[k * n + i] = x[i] - z[k].
// Matches CPU R4 `denoms.push(x_i - z_k)` convention.
//
// Output is ext3-interleaved of length 3 * k_scalars * n.
//
// Launched as grid = ceil(total / BLOCK_SIZE), where total = k_scalars * n.
// Each thread builds one denom.
// ---------------------------------------------------------------------------
extern "C" __global__ void compute_denoms_ext3(
const uint64_t *x_base, // n u64
const uint64_t *z_scalars, // 3 * k_scalars u64
uint64_t n,
uint64_t k_scalars,
uint64_t subtract_x, // 0: z - x; 1: x - z
uint64_t *denoms_out // 3 * k_scalars * n u64
) {
uint64_t flat = (uint64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x;
uint64_t total = k_scalars * n;
if (flat >= total) return;

uint64_t k = flat / n;
uint64_t i = flat - k * n;

uint64_t x_i = x_base[i];
ext3::Fe3 z = {
z_scalars[k * 3 + 0],
z_scalars[k * 3 + 1],
z_scalars[k * 3 + 2],
};
ext3::Fe3 d;
if (subtract_x == 0) {
// z - x: lift x to (x, 0, 0), subtract from z.
d.a = goldilocks::sub(z.a, x_i);
d.b = z.b;
d.c = z.c;
} else {
// x - z: lift x to (x, 0, 0), subtract z.
d.a = goldilocks::sub(x_i, z.a);
d.b = goldilocks::neg(z.b);
d.c = goldilocks::neg(z.c);
}

denoms_out[flat * 3 + 0] = d.a;
denoms_out[flat * 3 + 1] = d.b;
denoms_out[flat * 3 + 2] = d.c;
}

// ---------------------------------------------------------------------------
// 2. block_inclusive_scan_fwd_ext3
//
// Per-block forward Hillis-Steele inclusive scan with multiplication. Writes
// scan_out[gid] = product of input[block_start..=gid] and block_totals[bid] =
// the product over the entire block.
//
// Threads handle out-of-range positions by loading the identity element (1),
// so a partial last block still produces a correct scan.
// ---------------------------------------------------------------------------
extern "C" __global__ void block_inclusive_scan_fwd_ext3(
const uint64_t *input, // 3 * n u64
uint64_t n,
uint64_t *scan_out, // 3 * n u64
uint64_t *block_totals // 3 * K u64, K = ceil(n / BLOCK_SIZE)
) {
__shared__ ext3::Fe3 shmem[BLOCK_SIZE];
uint64_t tid = threadIdx.x;
uint64_t gid = (uint64_t)blockIdx.x * BLOCK_SIZE + tid;

// Load input or identity.
if (gid < n) {
shmem[tid].a = input[gid * 3 + 0];
shmem[tid].b = input[gid * 3 + 1];
shmem[tid].c = input[gid * 3 + 2];
} else {
shmem[tid] = ext3::one();
}
__syncthreads();

// Hillis-Steele inclusive scan: 8 doubling levels for BLOCK_SIZE = 256.
for (uint32_t offset = 1; offset < BLOCK_SIZE; offset <<= 1) {
ext3::Fe3 prev = (tid >= offset) ? shmem[tid - offset] : ext3::one();
__syncthreads();
if (tid >= offset) {
shmem[tid] = ext3::mul(prev, shmem[tid]);
}
__syncthreads();
}

// Write per-element scan result.
if (gid < n) {
scan_out[gid * 3 + 0] = shmem[tid].a;
scan_out[gid * 3 + 1] = shmem[tid].b;
scan_out[gid * 3 + 2] = shmem[tid].c;
}

// Block total = scan value at the last VALID thread of this block.
// The last valid gid in this block is min(block_end - 1, n - 1).
// Computing it explicitly (instead of `tid == 255 || gid == n - 1`)
// ensures EXACTLY ONE thread writes per block — in a partial last
// block the two conditions would otherwise both fire and race.
uint64_t block_end = ((uint64_t)blockIdx.x + 1) * BLOCK_SIZE;
uint64_t last_valid_gid = (block_end - 1 < n - 1) ? (block_end - 1) : (n - 1);
if (gid == last_valid_gid) {
block_totals[(uint64_t)blockIdx.x * 3 + 0] = shmem[tid].a;
block_totals[(uint64_t)blockIdx.x * 3 + 1] = shmem[tid].b;
block_totals[(uint64_t)blockIdx.x * 3 + 2] = shmem[tid].c;
}
}

// ---------------------------------------------------------------------------
// 3. apply_block_offsets_fwd_ext3
//
// Phase 3 of the forward scan: each block b > 0 multiplies its per-block
// scan by `block_totals_scanned[b-1]` (the inclusive prefix of preceding
// block totals). Block 0 has no offset, so it returns early.
// ---------------------------------------------------------------------------
extern "C" __global__ void apply_block_offsets_fwd_ext3(
uint64_t *scan_inout, // 3 * n u64 (modified in place)
uint64_t n,
const uint64_t *block_totals_scanned // 3 * K u64, inclusive prefix of phase-1 totals
) {
if (blockIdx.x == 0) return;
uint64_t tid = threadIdx.x;
uint64_t gid = (uint64_t)blockIdx.x * BLOCK_SIZE + tid;
if (gid >= n) return;

ext3::Fe3 offset = {
block_totals_scanned[(blockIdx.x - 1) * 3 + 0],
block_totals_scanned[(blockIdx.x - 1) * 3 + 1],
block_totals_scanned[(blockIdx.x - 1) * 3 + 2],
};
ext3::Fe3 val = {
scan_inout[gid * 3 + 0],
scan_inout[gid * 3 + 1],
scan_inout[gid * 3 + 2],
};
ext3::Fe3 res = ext3::mul(offset, val);
scan_inout[gid * 3 + 0] = res.a;
scan_inout[gid * 3 + 1] = res.b;
scan_inout[gid * 3 + 2] = res.c;
}

// ---------------------------------------------------------------------------
// 4. block_inclusive_scan_rev_ext3
//
// Mirror of `block_inclusive_scan_fwd_ext3` for the suffix product:
// suffix[i] = input[i] * input[i+1] * ... * input[n-1]
//
// Block b processes pos_from_end in [b*B, (b+1)*B), where gid = n-1-pos_from_end.
// Inside shmem the order is reversed so a forward Hillis-Steele scan over
// the loaded values produces the suffix scan in the original index space.
// ---------------------------------------------------------------------------
extern "C" __global__ void block_inclusive_scan_rev_ext3(
const uint64_t *input,
uint64_t n,
uint64_t *scan_out,
uint64_t *block_totals
) {
__shared__ ext3::Fe3 shmem[BLOCK_SIZE];
uint64_t tid = threadIdx.x;
uint64_t pos_from_end = (uint64_t)blockIdx.x * BLOCK_SIZE + tid;
bool valid = pos_from_end < n;
uint64_t gid = valid ? (n - 1 - pos_from_end) : 0;

if (valid) {
shmem[tid].a = input[gid * 3 + 0];
shmem[tid].b = input[gid * 3 + 1];
shmem[tid].c = input[gid * 3 + 2];
} else {
shmem[tid] = ext3::one();
}
__syncthreads();

for (uint32_t offset = 1; offset < BLOCK_SIZE; offset <<= 1) {
ext3::Fe3 prev = (tid >= offset) ? shmem[tid - offset] : ext3::one();
__syncthreads();
if (tid >= offset) {
shmem[tid] = ext3::mul(prev, shmem[tid]);
}
__syncthreads();
}

if (valid) {
scan_out[gid * 3 + 0] = shmem[tid].a;
scan_out[gid * 3 + 1] = shmem[tid].b;
scan_out[gid * 3 + 2] = shmem[tid].c;
}

// Mutually-exclusive last-thread mask (same idea as fwd): the last
// valid pos_from_end in this block is min(block_end - 1, n - 1).
uint64_t block_end_rev = ((uint64_t)blockIdx.x + 1) * BLOCK_SIZE;
uint64_t last_valid_pos = (block_end_rev - 1 < n - 1) ? (block_end_rev - 1) : (n - 1);
if (pos_from_end == last_valid_pos) {
block_totals[(uint64_t)blockIdx.x * 3 + 0] = shmem[tid].a;
block_totals[(uint64_t)blockIdx.x * 3 + 1] = shmem[tid].b;
block_totals[(uint64_t)blockIdx.x * 3 + 2] = shmem[tid].c;
}
}

// ---------------------------------------------------------------------------
// 5. apply_block_offsets_rev_ext3
//
// Phase 3 of the suffix scan. Block b > 0 multiplies its per-block scan
// by the inclusive prefix of block totals from blocks [0..b-1] (which, in
// the reverse-block indexing, correspond to the indices LARGER than this
// block's gids).
// ---------------------------------------------------------------------------
extern "C" __global__ void apply_block_offsets_rev_ext3(
uint64_t *scan_inout,
uint64_t n,
const uint64_t *block_totals_scanned
) {
if (blockIdx.x == 0) return;
uint64_t tid = threadIdx.x;
uint64_t pos_from_end = (uint64_t)blockIdx.x * BLOCK_SIZE + tid;
if (pos_from_end >= n) return;
uint64_t gid = n - 1 - pos_from_end;

ext3::Fe3 offset = {
block_totals_scanned[(blockIdx.x - 1) * 3 + 0],
block_totals_scanned[(blockIdx.x - 1) * 3 + 1],
block_totals_scanned[(blockIdx.x - 1) * 3 + 2],
};
ext3::Fe3 val = {
scan_inout[gid * 3 + 0],
scan_inout[gid * 3 + 1],
scan_inout[gid * 3 + 2],
};
ext3::Fe3 res = ext3::mul(offset, val);
scan_inout[gid * 3 + 0] = res.a;
scan_inout[gid * 3 + 1] = res.b;
scan_inout[gid * 3 + 2] = res.c;
}

// ---------------------------------------------------------------------------
// 6. batch_inverse_combine_ext3
//
// out[i] = prefix[i-1] * inv_total * suffix[i+1]
//
// Boundaries: prefix[-1] = identity, suffix[n] = identity.
// inv_total = 1 / (prefix[n-1]) = 1 / (suffix[0]); the caller computes it
// on host via Fermat's little theorem (one extension-field inverse per
// batch) and uploads as a 3 * u64 device buffer.
// ---------------------------------------------------------------------------
extern "C" __global__ void batch_inverse_combine_ext3(
const uint64_t *prefix, // 3 * n u64
const uint64_t *suffix, // 3 * n u64
const uint64_t *inv_total, // 3 u64
uint64_t n,
uint64_t *out // 3 * n u64
) {
uint64_t i = (uint64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x;
if (i >= n) return;

ext3::Fe3 inv_t = {inv_total[0], inv_total[1], inv_total[2]};

ext3::Fe3 p;
if (i == 0) {
p = ext3::one();
} else {
p.a = prefix[(i - 1) * 3 + 0];
p.b = prefix[(i - 1) * 3 + 1];
p.c = prefix[(i - 1) * 3 + 2];
}

ext3::Fe3 s;
if (i == n - 1) {
s = ext3::one();
} else {
s.a = suffix[(i + 1) * 3 + 0];
s.b = suffix[(i + 1) * 3 + 1];
s.c = suffix[(i + 1) * 3 + 2];
}

ext3::Fe3 tmp = ext3::mul(p, inv_t);
ext3::Fe3 res = ext3::mul(tmp, s);

out[i * 3 + 0] = res.a;
out[i * 3 + 1] = res.b;
out[i * 3 + 2] = res.c;
}
Loading
Loading