diff --git a/README.md b/README.md index ee164a3..059652a 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,13 @@ [![Rustc Version 1.85.0+](https://img.shields.io/badge/rustc-1.85.0+-lightgray.svg)](https://blog.rust-lang.org/2025/02/20/Rust-1.85.0/) [![CI](https://github.com/M4SS-Code/massping/actions/workflows/ci.yml/badge.svg)](https://github.com/M4SS-Code/massping/actions/workflows/ci.yml) -Asynchronous ICMP ping library using Linux RAW sockets and the +Asynchronous ICMP ping library using Linux DGRAM sockets and the tokio runtime. +This crate uses `SOCK_DGRAM` sockets with `IPPROTO_ICMP`/`IPPROTO_ICMPV6` +("ping sockets"), which allows sending ICMP echo requests without root +privileges on Linux. + ## Features * `stream`: implements `Stream` for `MeasureManyStream`. diff --git a/src/lib.rs b/src/lib.rs index 33b4a9f..55d331a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,10 @@ pub mod raw_pinger; mod socket; /// A pinger for both [`Ipv4Addr`] and [`Ipv6Addr`] addresses. +/// +/// Cloning is cheap: clones share the same sockets and background +/// receive tasks, which shut down when the last clone is dropped. +#[derive(Clone)] pub struct DualstackPinger { v4: V4Pinger, v6: V6Pinger, @@ -57,9 +61,14 @@ impl DualstackPinger { /// Construct a new `DualstackPinger`. /// /// For maximum efficiency the same instance of `DualstackPinger` should - /// be used for as long as possible, altough it might also + /// be used for as long as possible, although it might also /// be beneficial to `Drop` the `DualstackPinger` and recreate it if /// you are not going to be sending pings for a long period of time. + /// + /// # Panics + /// + /// Panics if called from outside a tokio runtime, as it spawns + /// background receive tasks. pub fn new() -> io::Result { let v4 = V4Pinger::new()?; let v6 = V6Pinger::new()?; @@ -71,6 +80,10 @@ impl DualstackPinger { /// Creates [`DualstackMeasureManyStream`] which **lazily** sends ping /// requests and [`Stream`]s the responses as they arrive. /// + /// # Panics + /// + /// See [`Pinger::measure_many`]. + /// /// [`Stream`]: futures_core::Stream pub fn measure_many(&self, addresses: I) -> DualstackMeasureManyStream<'_, I> where @@ -100,10 +113,12 @@ impl DualstackPinger { /// like [`tokio::time::timeout`] should be used to prevent the program /// from hanging indefinitely. /// -/// Leaking this method might crate a slowly forever growing memory leak. +/// Leaking this stream may create a memory leak that lasts until the +/// [`DualstackPinger`] is dropped. /// /// [`Stream`]: futures_core::Stream /// [`tokio::time::timeout`]: tokio::time::timeout +#[must_use = "streams do nothing unless polled"] pub struct DualstackMeasureManyStream<'a, I: Iterator> { v4: MeasureManyStream<'a, Ipv4Addr, FilterIpAddr>, v6: MeasureManyStream<'a, Ipv6Addr, FilterIpAddr>, diff --git a/src/packet.rs b/src/packet.rs index 1ea389a..1ecda62 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -38,6 +38,13 @@ pub struct EchoReplyPacket { impl EchoRequestPacket { /// Build a new ICMP echo request packet + /// + /// Note that when the packet is sent through a Linux `SOCK_DGRAM` ICMP + /// socket ("ping socket"), as done by this crate, the kernel overwrites + /// `identifier` with the socket's own identifier and recomputes the + /// checksum. The kernel also delivers only the echo replies whose + /// identifier matches the socket, so replies don't need to be checked + /// against the value given here. pub fn new(identifier: u16, sequence_number: u16, payload: &[u8]) -> Self { let mut buf = BytesMut::zeroed(ICMP_HEADER_LEN + payload.len()); @@ -66,6 +73,11 @@ impl EchoRequestPacket { pub(crate) fn as_bytes(&self) -> &[u8] { &self.buf } + + /// Get the payload of this echo request. + pub(crate) fn payload(&self) -> Bytes { + self.buf.slice(ICMP_HEADER_LEN..) + } } impl EchoReplyPacket { @@ -103,6 +115,9 @@ impl EchoReplyPacket { } /// Get the ICMP packet identifier + /// + /// On Linux ping sockets this is the kernel-assigned identifier of the + /// receiving socket, not the value passed to [`EchoRequestPacket::new`]. pub fn identifier(&self) -> u16 { self.identifier } @@ -141,11 +156,71 @@ fn internet_checksum(data: &[u8]) -> u16 { #[cfg(test)] mod tests { - use std::net::Ipv4Addr; + use std::net::{Ipv4Addr, Ipv6Addr}; use bytes::Bytes; - use super::EchoReplyPacket; + use super::{EchoReplyPacket, EchoRequestPacket, internet_checksum}; + + /// Well-known example from the "Internet checksum" Wikipedia article: + /// an IPv4 header (checksum field zeroed) whose checksum is 0xB861. + #[test] + fn internet_checksum_reference_vector() { + let ip_header = [ + 0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00, 0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, + 0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7, + ]; + assert_eq!(internet_checksum(&ip_header), 0xb861); + } + + #[test] + fn internet_checksum_empty() { + assert_eq!(internet_checksum(&[]), 0xffff); + } + + /// The end-around carry must be folded back into the sum. + #[test] + fn internet_checksum_folds_carry() { + assert_eq!(internet_checksum(&[0xff; 8]), 0x0000); + } + + /// The trailing byte of odd-length data is padded with a zero byte, + /// i.e. it forms the high-order byte of the last 16-bit word. + #[test] + fn internet_checksum_odd_length() { + assert_eq!(internet_checksum(&[0x01, 0x02, 0x03]), !(0x0102 + 0x0300)); + } + + #[test] + fn echo_request_packet_v4_layout() { + let packet = EchoRequestPacket::::new(0x1234, 0x5678, b"test"); + let buf = packet.as_bytes(); + + assert_eq!(buf.len(), 12); + assert_eq!(buf[0], 8, "ICMPv4 echo request type"); + assert_eq!(buf[1], 0, "code"); + assert_eq!(buf[2..4], 0xa779u16.to_be_bytes(), "checksum"); + assert_eq!(buf[4..6], 0x1234u16.to_be_bytes(), "identifier"); + assert_eq!(buf[6..8], 0x5678u16.to_be_bytes(), "sequence number"); + assert_eq!(&buf[8..], b"test"); + + // A packet whose checksum field is correct sums to zero. + assert_eq!(internet_checksum(buf), 0); + assert_eq!(&packet.payload()[..], b"test"); + } + + #[test] + fn echo_request_packet_v6_uses_icmpv6_type() { + // Odd-length payload to also exercise the checksum padding. + let packet = EchoRequestPacket::::new(1, 2, b"abc"); + let buf = packet.as_bytes(); + + assert_eq!(buf[0], 128, "ICMPv6 echo request type"); + // The kernel recomputes the ICMPv6 checksum (it includes a + // pseudo-header userspace can't know), but the packet must still be + // self-consistent under the plain internet checksum. + assert_eq!(internet_checksum(buf), 0); + } #[test] fn from_reply_rejects_truncated_packet() { diff --git a/src/pinger.rs b/src/pinger.rs index 670b80d..d097fd3 100644 --- a/src/pinger.rs +++ b/src/pinger.rs @@ -8,13 +8,13 @@ use std::{ net::{Ipv4Addr, Ipv6Addr}, sync::{ Arc, - atomic::{AtomicU16, Ordering}, + atomic::{AtomicU64, Ordering}, }, task::{Context, Poll}, time::Duration, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; #[cfg(feature = "stream")] use futures_core::Stream; use tokio::{ @@ -30,27 +30,51 @@ pub type V4Pinger = Pinger; pub type V6Pinger = Pinger; /// A pinger for [`IpVersion`] (either [`Ipv4Addr`] or [`Ipv6Addr`]). +/// +/// Cloning is cheap: clones share the same socket and background +/// receive task, which shut down when the last clone is dropped. pub struct Pinger { inner: Arc>, + // Kept out of `InnerPinger` (which the background receive task holds) + // so that dropping the last `Pinger` clone disconnects the channel, + // telling the background task to shut down and release the socket. + round_sender: mpsc::UnboundedSender>, +} + +impl Clone for Pinger { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + round_sender: self.round_sender.clone(), + } + } } struct InnerPinger { raw: RawPinger, - round_sender: mpsc::UnboundedSender>, - identifier: u16, - sequence_number: AtomicU16, + next_round_id: AtomicU64, } +// Each `measure_many` round gets a unique `u64` id; the wire sequence +// number is its lower 16 bits. The full id lets the receive task tell +// rounds apart after the sequence number wraps around. enum RoundMessage { Subscribe { - sequence_number: u16, + round_id: u64, + expected_payload: Bytes, sender: mpsc::UnboundedSender<(V, Instant)>, }, Unsubscribe { - sequence_number: u16, + round_id: u64, }, } +struct Subscriber { + round_id: u64, + expected_payload: Bytes, + sender: mpsc::UnboundedSender<(V, Instant)>, +} + enum PollResult { Subscription(RoundMessage), Packet(crate::packet::EchoReplyPacket), @@ -60,27 +84,30 @@ impl Pinger { /// Construct a new `Pinger`. /// /// For maximum efficiency the same instance of `Pinger` should - /// be used for as long as possible, altough it might also + /// be used for as long as possible, although it might also /// be beneficial to `Drop` the `Pinger` and recreate it if /// you are not going to be sending pings for a long period of time. + /// + /// # Panics + /// + /// Panics if called from outside a tokio runtime, as it spawns a + /// background receive task. pub fn new() -> io::Result { let raw = RawPinger::new()?; - let identifier = rand::random::(); - let (sender, mut receiver) = mpsc::unbounded_channel(); let inner = Arc::new(InnerPinger { raw, - round_sender: sender, - identifier, - sequence_number: AtomicU16::new(0), + next_round_id: AtomicU64::new(0), }); - // Spawn async receive task using the same socket + // Spawn async receive task using the same socket. + // It runs until `receiver` disconnects, which happens when the + // `Pinger` holding the only sender is dropped. let inner_recv = Arc::clone(&inner); tokio::spawn(async move { - let mut subscribers: HashMap> = HashMap::new(); + let mut subscribers: HashMap> = HashMap::new(); // Buffer kept outside poll_fn so it persists across polls. let mut recv_buf = BytesMut::new(); @@ -103,10 +130,21 @@ impl Pinger { } // Try to receive an ICMP packet - if let Poll::Ready(Ok(packet)) = inner_recv.raw.poll_recv(&mut recv_buf, cx) { - return Poll::Ready(Some(PollResult::Packet(packet))); + match inner_recv.raw.poll_recv(&mut recv_buf, cx) { + Poll::Ready(Ok(packet)) => { + return Poll::Ready(Some(PollResult::Packet(packet))); + } + Poll::Ready(Err(_)) => { + // Receiving failed (typically a transient kernel + // resource error). The socket readiness was + // consumed without registering a waker, so ask to + // be polled again right away; parking here would + // suspend reply processing until an unrelated + // subscription message wakes the task. + cx.waker().wake_by_ref(); + } + Poll::Pending => {} } - // Socket error or not ready - continue polling // Register waker for subscription channel // We need to wake up when new subscriptions arrive @@ -124,15 +162,35 @@ impl Pinger { match result { Some(PollResult::Subscription(RoundMessage::Subscribe { - sequence_number, + round_id, + expected_payload, sender, })) => { - subscribers.insert(sequence_number, sender); + // A new round may displace a still-subscribed round + // whose sequence number collided after wraparound; + // the displaced round could not be served anyway as + // replies can only be told apart by sequence number. + subscribers.insert( + round_id as u16, + Subscriber { + round_id, + expected_payload, + sender, + }, + ); } - Some(PollResult::Subscription(RoundMessage::Unsubscribe { - sequence_number, - })) => { - subscribers.remove(&sequence_number); + Some(PollResult::Subscription(RoundMessage::Unsubscribe { round_id })) => { + let sequence_number = round_id as u16; + // Only unsubscribe if the slot still belongs to this + // round: after sequence number wraparound it may have + // been taken over by a newer round, which must keep + // receiving replies. + if subscribers + .get(&sequence_number) + .is_some_and(|subscriber| subscriber.round_id == round_id) + { + subscribers.remove(&sequence_number); + } } Some(PollResult::Packet(packet)) => { let recv_instant = Instant::now(); @@ -141,7 +199,20 @@ impl Pinger { let packet_sequence_number = packet.sequence_number(); if let Some(subscriber) = subscribers.get(&packet_sequence_number) { - if subscriber.send((packet_source, recv_instant)).is_err() { + // An echo reply mirrors the request's payload, so + // a mismatch means the reply wasn't produced by + // this round (e.g. a reply to an older round whose + // sequence number collided after wraparound, or + // blindly spoofed cross-traffic). Discard it. + let payload_matches = + packet.payload() == &subscriber.expected_payload[..]; + + if payload_matches + && subscriber + .sender + .send((packet_source, recv_instant)) + .is_err() + { subscribers.remove(&packet_sequence_number); } } @@ -151,7 +222,10 @@ impl Pinger { } }); - Ok(Self { inner }) + Ok(Self { + inner, + round_sender: sender, + }) } /// Ping `addresses` @@ -159,6 +233,16 @@ impl Pinger { /// Creates [`MeasureManyStream`] which **lazily** sends ping /// requests and [`Stream`]s the responses as they arrive. /// + /// Replies are matched by source address, so an address that appears + /// multiple times is only pinged once per round and yields a single + /// measurement. + /// + /// # Panics + /// + /// Panics if the background receive task has terminated, which only + /// happens when the runtime the `Pinger` was created on has been + /// shut down. + /// /// [`Stream`]: futures_core::Stream pub fn measure_many(&self, addresses: I) -> MeasureManyStream<'_, V, I> where @@ -168,12 +252,25 @@ impl Pinger { let send_queue = addresses.into_iter().peekable(); let (sender, receiver) = mpsc::unbounded_channel(); - let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel); + // Relaxed is enough: the counter is a pure id allocator, no other + // memory is synchronized through it. + let round_id = self.inner.next_round_id.fetch_add(1, Ordering::Relaxed); + + // The same packet is reused for every address of the round. Its + // random payload lets the receive task discard replies that don't + // belong to this round. + // + // The identifier is irrelevant: the kernel overwrites it with the + // socket's own identifier, which it also uses to route echo replies + // back to this socket. + let payload = rand::random::<[u8; 64]>(); + let packet = EchoRequestPacket::new(0, round_id as u16, &payload); + if self - .inner .round_sender .send(RoundMessage::Subscribe { - sequence_number, + round_id, + expected_payload: packet.payload(), sender, }) .is_err() @@ -183,10 +280,11 @@ impl Pinger { MeasureManyStream { pinger: self, + packet, send_queue, in_flight: HashMap::with_capacity(size_hint), receiver, - sequence_number, + round_id, } } } @@ -197,16 +295,19 @@ impl Pinger { /// like [`tokio::time::timeout`] should be used to prevent the program /// from hanging indefinitely. /// -/// Leaking this method might crate a slowly forever growing memory leak. +/// Leaking this stream may create a memory leak that lasts until the +/// [`Pinger`] is dropped. /// /// [`Stream`]: futures_core::Stream /// [`tokio::time::timeout`]: tokio::time::timeout +#[must_use = "streams do nothing unless polled"] pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator> { pinger: &'a Pinger, + packet: EchoRequestPacket, send_queue: Peekable, in_flight: HashMap, receiver: mpsc::UnboundedReceiver<(V, Instant)>, - sequence_number: u16, + round_id: u64, } impl> MeasureManyStream<'_, V, I> { @@ -229,23 +330,40 @@ impl> MeasureManyStream<'_, V, I> { fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) { while let Some(&addr) = self.send_queue.peek() { - let payload = rand::random::<[u8; 64]>(); - - let packet = EchoRequestPacket::new( - self.pinger.inner.identifier, - self.sequence_number, - &payload, - ); - match self.pinger.inner.raw.poll_send_to(cx, addr, &packet) { - Poll::Ready(_) => { + // Replies are matched by source address within a round, so a + // second ping to an address that is still awaiting its reply + // could never produce a second measurement; it would only + // clobber the first ping's start time. Skip the duplicate. + if self.in_flight.contains_key(&addr) { + self.send_queue.next(); + continue; + } + + match self.pinger.inner.raw.poll_send_to(cx, addr, &self.packet) { + Poll::Ready(result) => { let sent_at = Instant::now(); let taken_addr = self.send_queue.next(); debug_assert!(taken_addr.is_some()); - self.in_flight.insert(addr, sent_at); + // If the send failed (e.g. no route to host) no reply + // can ever arrive, so don't track the address as + // in-flight or the stream would never terminate. + if result.is_ok() { + self.in_flight.insert(addr, sent_at); + } + } + Poll::Pending => { + // The socket only remembers the most recent waker per + // direction (`AsyncFd` semantics), so with multiple + // streams sharing the socket another stream could + // overwrite ours and we'd never be woken again. Sends + // only return `Pending` while the send buffer is full, + // which clears up quickly, so schedule an immediate + // re-poll instead of parking. + cx.waker().wake_by_ref(); + break; } - Poll::Pending => break, } } } @@ -280,12 +398,8 @@ impl + Unpin> Stream for MeasureManyStream<' impl> Drop for MeasureManyStream<'_, V, I> { fn drop(&mut self) { - let _ = self - .pinger - .inner - .round_sender - .send(RoundMessage::Unsubscribe { - sequence_number: self.sequence_number, - }); + let _ = self.pinger.round_sender.send(RoundMessage::Unsubscribe { + round_id: self.round_id, + }); } } diff --git a/src/raw_pinger.rs b/src/raw_pinger.rs index fe7fb54..68bdb25 100644 --- a/src/raw_pinger.rs +++ b/src/raw_pinger.rs @@ -20,6 +20,10 @@ pub type RawV4Pinger = RawPinger; pub type RawV6Pinger = RawPinger; /// Asynchronous pinger +/// +/// The underlying socket remembers at most one waker per direction, so at +/// most one task should be sending and one task receiving at any given +/// time; concurrent polls in the same direction can lose wakeups. pub struct RawPinger { socket: Socket, _version: PhantomData, @@ -58,6 +62,8 @@ impl RawPinger { } /// Receive an ICMP ECHO reply packet + /// + /// Replies larger than 2048 bytes are silently truncated. pub fn recv(&self) -> RecvFuture<'_, V> { RecvFuture { pinger: self, @@ -66,14 +72,20 @@ impl RawPinger { } /// Receive an ICMP ECHO reply packet + /// + /// Replies larger than 2048 bytes are silently truncated. pub fn poll_recv( &self, buf: &mut BytesMut, cx: &mut Context<'_>, ) -> Poll>> { let (buf, source) = ready!(self.socket.poll_read(buf, cx))?; - let source = V::from_ip_addr(source.ip()).unwrap(); - match EchoReplyPacket::from_reply(source, buf) { + // Skip packets that don't come from the expected address family or + // aren't valid echo replies, and ask to be polled again: more + // packets may already be queued on the socket. + let packet = V::from_ip_addr(source.ip()) + .and_then(|source| EchoReplyPacket::from_reply(source, buf)); + match packet { Some(packet) => Poll::Ready(Ok(packet)), None => { cx.waker().wake_by_ref(); @@ -84,6 +96,7 @@ impl RawPinger { } /// [`Future`] obtained from [`RawPinger::send_to`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] pub struct SendFuture<'a, V: IpVersion> { pinger: &'a RawPinger, addr: V, @@ -99,6 +112,7 @@ impl Future for SendFuture<'_, V> { } /// [`Future`] obtained from [`RawPinger::recv`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] pub struct RecvFuture<'a, V: IpVersion> { pinger: &'a RawPinger, buf: BytesMut, @@ -109,7 +123,6 @@ impl Future for RecvFuture<'_, V> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let packet = ready!(self.pinger.poll_recv(&mut self.buf, cx))?; - // SAFETY: `RawPinger` already checked that the packet is valid Poll::Ready(Ok(packet)) } } diff --git a/src/socket/mod.rs b/src/socket/mod.rs index c1f4fbf..6f1d9cc 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -36,7 +36,8 @@ impl Socket { match guard.try_io(|inner| inner.get_ref().recv(buf.spare_capacity_mut())) { Ok(Ok((n, source))) => { - // SAFETY: `poll_recv` guarantees that `n` has been filled + // SAFETY: `BaseSocket::recv` guarantees that the first + // `n` bytes of the spare capacity have been filled unsafe { buf.advance_mut(n) } return Poll::Ready(Ok((buf.split().freeze(), source))); diff --git a/tests/drop_behavior.rs b/tests/drop_behavior.rs new file mode 100644 index 0000000..1a02271 --- /dev/null +++ b/tests/drop_behavior.rs @@ -0,0 +1,48 @@ +//! Regression test for resource cleanup on `Pinger` drop. +//! +//! The background receive task used to hold (through `Arc`) +//! the only sender of its own subscription channel, so the channel never +//! disconnected, the task never exited and the ICMP socket was never +//! closed. Every `Pinger` created and dropped leaked a task and a file +//! descriptor for the lifetime of the runtime. + +use massping::V4Pinger; + +fn count_fds() -> usize { + std::fs::read_dir("/proc/self/fd").unwrap().count() +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_pinger_closes_socket_and_stops_task() { + // Warm up anything created lazily (runtime resources, RNG, ...) so the + // baseline below is stable. + drop(V4Pinger::new().unwrap()); + for _ in 0..100 { + tokio::task::yield_now().await; + } + let baseline = count_fds(); + + let pingers: Vec<_> = (0..5).map(|_| V4Pinger::new().unwrap()).collect(); + assert!( + count_fds() >= baseline + 5, + "each pinger should hold a socket fd" + ); + drop(pingers); + + // The background tasks observe the subscription channel closing and + // exit, dropping their sockets. Yield so the current_thread runtime + // gets a chance to run them. + let mut fds = count_fds(); + for _ in 0..100 { + if fds == baseline { + break; + } + tokio::task::yield_now().await; + fds = count_fds(); + } + + assert_eq!( + fds, baseline, + "background task/socket leaked after dropping the pingers" + ); +} diff --git a/tests/sequence_wraparound.rs b/tests/sequence_wraparound.rs new file mode 100644 index 0000000..5cee11f --- /dev/null +++ b/tests/sequence_wraparound.rs @@ -0,0 +1,41 @@ +#![cfg(feature = "stream")] + +//! Regression test for sequence number wraparound. +//! +//! Rounds are identified on the wire by a `u16` sequence number, which +//! wraps around after 65536 `measure_many` calls. A long-lived round +//! dropped *after* the wraparound used to unsubscribe whichever newer +//! round had taken over its sequence number, leaving that round unable +//! to ever receive its replies. + +use std::{iter, net::Ipv4Addr, time::Duration}; + +use futures_util::StreamExt; +use massping::V4Pinger; +use tokio::time; + +#[tokio::test(flavor = "current_thread")] +async fn unsubscribe_after_wraparound_does_not_break_new_round() { + let pinger = V4Pinger::new().unwrap(); + + // Round with sequence number 0, kept alive across the wraparound. + let stale = pinger.measure_many(iter::empty::()); + + // Burn through the remaining 65535 sequence numbers. + for _ in 0..65535 { + drop(pinger.measure_many(iter::empty::())); + } + + // This round wraps around to sequence number 0, taking over the slot. + let mut current = pinger.measure_many([Ipv4Addr::LOCALHOST].into_iter()); + + // Dropping the stale round must not unsubscribe the current one. + drop(stale); + + let result = time::timeout(Duration::from_secs(5), current.next()).await; + match result { + Ok(Some((addr, _rtt))) => assert_eq!(addr, Ipv4Addr::LOCALHOST), + Ok(None) => panic!("stream ended without a reply"), + Err(_) => panic!("no reply: the stale unsubscribe broke the current round"), + } +} diff --git a/tests/stream_termination.rs b/tests/stream_termination.rs index 60a4725..20d4dac 100644 --- a/tests/stream_termination.rs +++ b/tests/stream_termination.rs @@ -4,10 +4,13 @@ //! ping requests have been sent and all responses have been received. Without //! this, `while let Some(...) = stream.next().await` loops hang forever. -use std::{net::IpAddr, time::Duration}; +use std::{ + net::{IpAddr, Ipv4Addr}, + time::Duration, +}; use futures_util::StreamExt; -use massping::DualstackPinger; +use massping::{DualstackPinger, V4Pinger}; use tokio::time; /// Test that the stream properly terminates after receiving all responses. @@ -42,9 +45,8 @@ async fn stream_terminates_after_single_ping() { /// Test that the stream properly terminates after receiving multiple responses. /// -/// Note: We use different addresses because `in_flight` is keyed by address, -/// so pinging the same address multiple times in one `measure_many` call -/// would overwrite the previous entry. +/// Note: We use different addresses because replies are matched by source +/// address, so duplicate addresses yield a single measurement per round. #[tokio::test(flavor = "current_thread")] async fn stream_terminates_after_multiple_pings() { let pinger = DualstackPinger::new().unwrap(); @@ -74,6 +76,57 @@ async fn stream_terminates_after_multiple_pings() { assert_eq!(count, 3, "expected exactly 3 ping responses"); } +/// Test that duplicate addresses yield a single measurement and terminate. +/// +/// Replies are matched by source address within a round, so a duplicate +/// can never yield a second measurement; it is only pinged once. +#[tokio::test(flavor = "current_thread")] +async fn duplicate_addresses_yield_single_measurement() { + let pinger = DualstackPinger::new().unwrap(); + let localhost: IpAddr = "127.0.0.1".parse().unwrap(); + let mut stream = pinger.measure_many([localhost, localhost, localhost].into_iter()); + + let mut count = 0; + + let result = time::timeout(Duration::from_secs(5), async { + while stream.next().await.is_some() { + count += 1; + } + }) + .await; + + assert!(result.is_ok(), "stream did not terminate"); + assert_eq!(count, 1, "duplicates should collapse into one measurement"); +} + +/// Test that the stream terminates when sends fail immediately. +/// +/// `sendto` fails synchronously with `EACCES` for the broadcast address +/// (the socket doesn't have `SO_BROADCAST`), so no reply can ever arrive. +/// The stream must not wait for one forever. +#[tokio::test(flavor = "current_thread")] +async fn stream_terminates_when_send_fails() { + let pinger = V4Pinger::new().unwrap(); + + let addresses: [Ipv4Addr; 1] = ["255.255.255.255".parse().unwrap()]; + let mut stream = pinger.measure_many(addresses.into_iter()); + + let mut count = 0; + + let result = time::timeout(Duration::from_secs(2), async { + while stream.next().await.is_some() { + count += 1; + } + }) + .await; + + assert!( + result.is_ok(), + "stream did not terminate after failed sends" + ); + assert_eq!(count, 0, "expected no responses for unsendable addresses"); +} + /// Test that an empty address list terminates immediately. #[tokio::test(flavor = "current_thread")] async fn stream_terminates_with_empty_input() {