diff --git a/src/lib.rs b/src/lib.rs index 5138ce1..86edfeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,20 @@ pub enum Error { MaxTokensTooLow, } +/// Failure modes for [`Ratelimiter::try_wait_n`]. +#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)] +pub enum TryWaitError { + /// Not enough tokens are available right now. Retry after the returned + /// duration — same semantics as the `Err` from [`Ratelimiter::try_wait`]. + #[error("insufficient tokens; retry after {0:?}")] + Insufficient(Duration), + /// `n` exceeds the bucket's current [`max_tokens`](Ratelimiter::max_tokens). + /// Waiting will never satisfy the request; the caller must reduce `n` or + /// increase `max_tokens`. + #[error("requested tokens exceed bucket capacity")] + ExceedsCapacity, +} + /// A lock-free token bucket ratelimiter. /// /// Tokens accumulate continuously based on elapsed time. Each `try_wait()` @@ -369,6 +383,57 @@ where core::hint::spin_loop(); } } + + /// Non-blocking attempt to atomically acquire `n` tokens. + /// + /// Like [`try_wait`](Ratelimiter::try_wait), but consumes `n` tokens in + /// a single atomic operation — partial consumption never occurs. + /// + /// `try_wait_n(0)` is a no-op and always returns `Ok(())`. When rate is + /// 0 (unlimited), always succeeds. + /// + /// # Errors + /// + /// - [`TryWaitError::Insufficient`] — not enough tokens right now; + /// retry after the returned duration. + /// - [`TryWaitError::ExceedsCapacity`] — `n` is greater than the + /// current [`max_tokens`](Ratelimiter::max_tokens). The bucket can + /// never hold enough, so waiting will not help; reduce `n` or + /// increase `max_tokens`. + pub fn try_wait_n(&self, n: u64) -> Result<(), TryWaitError> { + let rate = self.rate.load(Ordering::Relaxed); + if rate == 0 { + return Ok(()); + } + if n == 0 { + return Ok(()); + } + + if n > self.max_tokens.load(Ordering::Relaxed) { + return Err(TryWaitError::ExceedsCapacity); + } + + self.refill(); + + let cost = n.saturating_mul(TOKEN_SCALE); + loop { + let current = self.tokens.load(Ordering::Acquire); + if current < cost { + let deficit = cost - current; + let wait_ns = (deficit as u128 * 1_000 / rate as u128).max(1) as u64; + return Err(TryWaitError::Insufficient(Duration::from_nanos(wait_ns))); + } + + if self + .tokens + .compare_exchange_weak(current, current - cost, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + return Ok(()); + } + core::hint::spin_loop(); + } + } } const _: () = { @@ -790,4 +855,60 @@ mod tests { assert!(total >= 1000, "expected >= 1000, got {total}"); assert!(total <= 4000, "expected <= 4000, got {total}"); } + + #[test] + fn try_wait_n_basic() { + let clock = TestClock::new(); + let rl = Builder::with_clock(1000, clock) + .initial_available(10) + .build() + .unwrap(); + assert!(rl.try_wait_n(5).is_ok()); + assert!(rl.try_wait_n(5).is_ok()); + assert!(matches!( + rl.try_wait_n(1), + Err(TryWaitError::Insufficient(_)) + )); + } + + #[test] + fn try_wait_n_zero_is_noop() { + let rl = Ratelimiter::with_clock(1000, TestClock::new()); + // No tokens yet, but n=0 still succeeds and consumes nothing. + assert!(rl.try_wait_n(0).is_ok()); + assert_eq!(rl.available(), 0); + } + + #[test] + fn try_wait_n_unlimited() { + let rl = Ratelimiter::with_clock(0, TestClock::new()); + assert!(rl.try_wait_n(1_000_000).is_ok()); + } + + #[test] + fn try_wait_n_does_not_partially_consume() { + let clock = TestClock::new(); + let rl = Builder::with_clock(1000, clock) + .initial_available(5) + .build() + .unwrap(); + // Asking for more than available (but <= max_tokens) must fail without consuming. + assert!(matches!( + rl.try_wait_n(10), + Err(TryWaitError::Insufficient(_)) + )); + for _ in 0..5 { + assert!(rl.try_wait().is_ok()); + } + } + + #[test] + fn try_wait_n_exceeds_capacity() { + let clock = TestClock::new(); + let rl = Builder::with_clock(1000, clock) + .max_tokens(10) + .build() + .unwrap(); + assert_eq!(rl.try_wait_n(100), Err(TryWaitError::ExceedsCapacity)); + } }