From 4697c137257ac0e995498b82735fc1f87038a9dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 16 Nov 2025 14:50:04 +0100 Subject: [PATCH 01/26] make context ref mut --- example-service/src/client.rs | 7 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 8 +-- plugins/tests/service.rs | 14 ++-- tarpc/examples/compression.rs | 4 +- tarpc/examples/custom_transport.rs | 4 +- tarpc/examples/pubsub.rs | 23 ++++--- tarpc/examples/readme.rs | 4 +- tarpc/examples/tls_over_tcp.rs | 4 +- tarpc/examples/tracing.rs | 10 +-- tarpc/src/client.rs | 9 ++- tarpc/src/client/stub.rs | 6 +- tarpc/src/client/stub/load_balance.rs | 10 +-- tarpc/src/client/stub/mock.rs | 2 +- tarpc/src/client/stub/retry.rs | 2 +- tarpc/src/context.rs | 2 +- tarpc/src/lib.rs | 9 +-- tarpc/src/server.rs | 64 ++++++++++--------- tarpc/src/server/incoming.rs | 5 +- tarpc/src/server/request_hook.rs | 22 ++++--- tarpc/src/server/request_hook/after.rs | 4 +- tarpc/src/server/request_hook/before.rs | 16 +++-- .../server/request_hook/before_and_after.rs | 6 +- tarpc/src/transport/channel.rs | 6 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 52 +++++++++------ 26 files changed, 164 insertions(+), 135 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..c73122c07 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -34,10 +34,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = context::current(); + let mut context2 = context::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..1efab549d 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..55ec2730e 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::Context, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::Context, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..d38492bd7 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut context::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut context::Context) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: context::Context, + _: &mut context::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut context::Context) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + async fn foo(self, _: &mut context::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..783f2618a 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +134,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client.hello(&mut context::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..c99825d08 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut Context) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..c89f9e736 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,15 +263,20 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); + + for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + let mut context = context::current(); + client.receive(&mut context, topic.clone(), message.clone()).await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -342,26 +347,26 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..bb3deadc7 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +46,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..d81ea74a1 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + async fn ping(self, _: &mut Context) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut tarpc::context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..1bace43ce 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut context::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,9 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); + let mut ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut ctx, 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..96afc4c5f 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +153,10 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: context::Context { + deadline: ctx.deadline, + trace_context: ctx.trace_context.clone(), + }, span, request_id, request, @@ -881,7 +884,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel.call(&mut current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..2647c1321 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,7 +24,7 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) + async fn call(&self, ctx: &mut context::Context, request: Self::Req) -> Result; } @@ -35,7 +35,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +46,7 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + async fn call(&self, ctx: &mut context::Context, req: Self::Req) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..6c0f7b0df 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..6f0540797 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::Context, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..18c84f25f 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: context::Context, + ctx: &mut context::Context, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..f59d34dd9 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,7 +21,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Context { diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..17a06ec57 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,8 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -279,7 +280,7 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..d0cca7ad4 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -102,10 +102,9 @@ impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +112,15 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -360,10 +358,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -399,12 +398,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -748,11 +748,12 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> @@ -855,11 +856,11 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// @@ -875,7 +876,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,7 +884,7 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -977,6 +978,7 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use tracing_subscriber::filter::FilterExt; fn test_channel() -> ( Pin, Response>>>>, @@ -1039,8 +1041,8 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1060,14 +1062,14 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { + let serve = serve(move |ctx: &mut context::Context, i| async move { assert_eq!(ctx.deadline, some_time); Ok(i) - }); + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1101,21 +1103,21 @@ mod tests { } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1320,7 +1322,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..eddf3794e 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -63,9 +63,10 @@ where /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..64b97453a 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,11 +43,11 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { @@ -58,7 +58,8 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn before(self, hook: Hook) -> HookThenServe @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,15 +94,15 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) + /// }.boxed()) /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -151,8 +152,9 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..e2c49b2f1 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -59,14 +59,14 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..ad04cc784 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -87,13 +87,13 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::Context, req: Self::Req, ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +103,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +120,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -209,8 +210,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = context::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..e06f34113 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,13 +46,13 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..3c0c420aa 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -198,7 +198,7 @@ mod tests { format!("{request:?} is not an int"), ) }) - })) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -206,8 +206,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..e051b434e 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..1005ae116 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut context::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +38,10 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +55,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::Context) { loop { futures::pending!(); } @@ -73,7 +73,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +112,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +145,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -169,12 +169,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -195,9 +198,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -225,8 +232,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::current(); + let mut context2 = context::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,7 +255,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + async fn count(self, _: &mut context::Context) -> u32 { self.0 += 1; self.0 } @@ -262,8 +272,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) } From 1b605a3c48bfdd61db0467191c66bd23550b9c3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:57:48 +0100 Subject: [PATCH 02/26] run cargo fmt --- tarpc/examples/compression.rs | 4 +++- tarpc/examples/pubsub.rs | 11 ++++++++--- tarpc/examples/readme.rs | 4 +++- tarpc/src/client/stub.rs | 13 ++++++++++--- tarpc/src/server.rs | 31 +++++++++++++++++++++++-------- tarpc/src/transport/channel.rs | 19 +++++++++++-------- tarpc/tests/service_functional.rs | 4 +++- 7 files changed, 61 insertions(+), 25 deletions(-) diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 783f2618a..e703cc676 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -134,7 +134,9 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(&mut context::current(), "friend".into()).await? + client + .hello(&mut context::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index c89f9e736..4e132616f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -271,11 +271,12 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { let mut context = context::current(); - client.receive(&mut context, topic.clone(), message.clone()).await + client + .receive(&mut context, topic.clone(), message.clone()) + .await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -359,7 +360,11 @@ async fn main() -> anyhow::Result<()> { .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut context::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index bb3deadc7..c00c270f0 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -46,7 +46,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 2647c1321..14b6edf30 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,8 +24,11 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: &mut context::Context, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut context::Context, + request: Self::Req, + ) -> Result; } impl Stub for Channel @@ -46,7 +49,11 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: &mut context::Context, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d0cca7ad4..e08365964 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut context::Context, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut context::Context, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -104,7 +108,10 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -115,7 +122,10 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce(&'a mut context::Context, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::Context, + Req, + ) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; @@ -1062,10 +1072,13 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }.boxed()); + let serve = serve(move |ctx: &mut context::Context, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() + }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; @@ -1322,7 +1335,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 3c0c420aa..e064e6813 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,14 +191,17 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - }.boxed())) + .execute(serve(|_ctx, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() + })) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 1005ae116..f3adda2fb 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -38,7 +38,9 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); From 02ca335e504c2051ecf58d235d4990310dd81af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:58:52 +0100 Subject: [PATCH 03/26] cargo clippy --- tarpc/src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 96afc4c5f..9ef7a1acb 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -155,7 +155,7 @@ where .send(DispatchRequest { ctx: context::Context { deadline: ctx.deadline, - trace_context: ctx.trace_context.clone(), + trace_context: ctx.trace_context, }, span, request_id, From 8e1dce47fd473d84ff80b3186a8c8b162965726c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 18 Nov 2025 14:19:35 +0100 Subject: [PATCH 04/26] separate context into shared, client and server contexts. only transmit shared context between client and server --- example-service/src/client.rs | 5 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 10 +- plugins/tests/service.rs | 14 +-- tarpc/examples/compression.rs | 6 +- tarpc/examples/custom_transport.rs | 6 +- tarpc/examples/pubsub.rs | 20 ++-- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 9 +- tarpc/src/client.rs | 29 +++--- tarpc/src/client/in_flight_requests.rs | 6 +- tarpc/src/client/stub.rs | 23 +++-- tarpc/src/client/stub/load_balance.rs | 10 +- tarpc/src/client/stub/mock.rs | 2 +- tarpc/src/client/stub/retry.rs | 2 +- tarpc/src/context.rs | 98 ++++++++++++++++--- tarpc/src/lib.rs | 8 +- tarpc/src/server.rs | 90 +++++++---------- tarpc/src/server/incoming.rs | 2 +- tarpc/src/server/request_hook.rs | 14 +-- tarpc/src/server/request_hook/after.rs | 8 +- tarpc/src/server/request_hook/before.rs | 18 ++-- .../server/request_hook/before_and_after.rs | 2 +- tarpc/src/server/testing.rs | 2 +- tarpc/src/transport/channel.rs | 4 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 36 +++---- 28 files changed, 246 insertions(+), 196 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index c73122c07..dc7104bfd 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -10,6 +10,7 @@ use std::{net::SocketAddr, time::Duration}; use tarpc::{client, context, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; +use tarpc::context::ClientContext; #[derive(Parser)] struct Flags { @@ -34,8 +35,8 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = context::current(); - let mut context2 = context::current(); + let mut context = ClientContext::current(); + let mut context2 = ClientContext::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 1efab549d..0845783c7 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 55ec2730e..886b85b48 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,7 +375,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; /// /// #[service] /// pub trait Calculator { @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: &mut Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut ServerContext, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: &mut ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::ServerContext, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: &mut ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::ServerContext, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::ClientContext, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index d38492bd7..b03f3470f 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: &mut context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::ServerContext, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: &mut context::Context, s: String) -> String { + async fn bar(self, _: &mut context::ServerContext, s: String) -> String { s } - async fn baz(self, _: &mut context::Context) {} + async fn baz(self, _: &mut context::ServerContext) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: &mut context::Context, + _: &mut context::ServerContext, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: &mut context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::ServerContext, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: &mut context::Context) {} + async fn r#async(self, _: &mut context::ServerContext) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: &mut context::Context) {} + async fn foo(self, _: &mut context::ServerContext) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index e703cc676..663236731 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}!") } } @@ -134,9 +134,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client - .hello(&mut context::current(), "friend".into()) - .await? + client.hello(&mut context::ClientContext::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index c99825d08..1c682173d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: &mut Context) {} + async fn ping(self, _: &mut ServerContext) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 4e132616f..83c1371b9 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: &mut context::Context) -> Vec { + async fn topics(self, _: &mut context::ServerContext) -> Vec { self.topics.clone() } - async fn receive(self, _: &mut context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::ServerContext, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(&mut context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::ClientContext::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,7 +263,7 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: &mut context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -271,12 +271,10 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); + for client in subscribers.values_mut() { publications.push(async { - let mut context = context::current(); - client - .receive(&mut context, topic.clone(), message.clone()) - .await + client.receive(&mut context::ClientContext::current(), topic.clone(), message.clone()).await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -360,11 +358,7 @@ async fn main() -> anyhow::Result<()> { .await?; publisher - .publish( - &mut context::current(), - "history".into(), - "napoleon".to_string(), - ) + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c00c270f0..60daf4e45 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: &mut context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hello, {name}!") } } @@ -46,9 +46,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client - .hello(&mut context::current(), "Stim".to_string()) - .await?; + let hello = client.hello(&mut context::ClientContext::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index d81ea74a1..cc3c1690b 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -18,7 +18,7 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: &mut Context) -> String { + async fn ping(self, _: &mut ServerContext) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 1bace43ce..be1b539c1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: &mut context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { self.add_client - .add(&mut context::current(), x, x) + .add(&mut context::ClientContext::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,8 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let mut ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(&mut ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut context::ClientContext::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 9ef7a1acb..8d3b9f4a7 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::SharedContext, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,10 +153,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + ctx: ctx.clone(), span, request_id, request, @@ -460,7 +457,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -516,13 +513,15 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.clone(), }); + + //TODO: Feels like we could avoid either saving the request context in insert_request + // or submitting the context in start_request. + let full_context = context::ClientContext::new(ctx); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, full_context, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -717,7 +716,7 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, ClientContext::current(), Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -884,7 +883,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(&mut current(), "hi".to_string()).await; + let resp = channel.call(&mut ClientContext::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1094,7 +1093,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1119,7 +1118,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..a368a5a48 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -29,7 +29,7 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. @@ -56,7 +56,7 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, ) -> Result<(), AlreadyExistsError> { @@ -106,7 +106,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::ClientContext, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 14b6edf30..c7dc12008 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,11 +24,8 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call( - &self, - ctx: &mut context::Context, - request: Self::Req, - ) -> Result; + async fn call(&self, ctx: &mut context::ClientContext, request: Self::Req) + -> Result; } impl Stub for Channel @@ -38,7 +35,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: &mut context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -49,11 +46,13 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call( - &self, - ctx: &mut context::Context, - req: Self::Req, - ) -> Result { - self.clone().serve(ctx, req).await.map_err(RpcError::Server) + async fn call(&self, ctx: &mut context::ClientContext, req: Self::Req) -> Result { + let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); + + let res = self.clone().serve(&mut server_ctx, req).await.map_err(RpcError::Server); + + ctx.shared_context = server_ctx.shared_context; + + res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 6c0f7b0df..bf70ebe2a 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(&mut context::current(), 'a').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(&mut context::current(), 'b').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(&mut context::current(), 'c').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 6f0540797..451544433 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: &mut context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::ClientContext, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 18c84f25f..d93daa156 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: &mut context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index f59d34dd9..a96c49095 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -14,6 +14,7 @@ use std::{ convert::TryFrom, time::{Duration, Instant}, }; +use std::ops::{Deref, DerefMut}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -21,10 +22,10 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Debug)] +#[derive(Debug, Clone)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -38,6 +39,86 @@ pub struct Context { pub trace_context: trace::Context, } +/// Request context that carries request-scoped server side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks or service implementations. +/// It is build from the shared context sent from client to server. +/// +/// The context should not be stored directly in a server implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ServerContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, +} + +impl ServerContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ServerContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} +impl DerefMut for ServerContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + +/// Request context that carries request-scoped client side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks and stubs. +/// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. +/// +/// The context should not be stored directly in a stub implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ClientContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, + +} + +impl ClientContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ClientContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} + +impl DerefMut for ClientContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -91,17 +172,12 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(SharedContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } -/// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() -} - #[derive(Clone)] struct Deadline(Instant); @@ -111,7 +187,7 @@ impl Default for Deadline { } } -impl Context { +impl SharedContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -137,11 +213,11 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &SharedContext); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &SharedContext) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 17a06ec57..a83efae02 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: &mut context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: &mut context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,7 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::current(); +//! let mut context = context::ClientContext::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); @@ -284,7 +284,7 @@ pub enum ClientMessage { #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: context::SharedContext, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index e08365964..3b01d207d 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,11 +76,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve( - self, - ctx: &mut context::Context, - req: Self::Req, - ) -> Result; + async fn serve(self, ctx: &mut context::ServerContext, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -108,10 +104,7 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut context::Context, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -122,15 +115,12 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce( - &'a mut context::Context, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -371,7 +361,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -412,7 +402,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -762,7 +752,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -869,7 +859,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::current(); + /// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -886,15 +876,16 @@ impl InFlightRequest { span, request: Request { - mut context, + context, message, id: request_id, }, } = self; span.record("otel.name", message.name()); + let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { - let message = serve.serve(&mut context, message).await; + let message = serve.serve(&mut full_context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -1037,7 +1028,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1052,7 +1043,7 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); + assert_matches!(serve.serve(&mut context::ServerContext::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1061,7 +1052,7 @@ mod tests { impl BeforeRequest for SetDeadline { async fn before( &mut self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { ctx.deadline = self.0; @@ -1072,15 +1063,12 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| { - async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - } - .boxed() - }); + let serve = serve(move |ctx: &mut context::ServerContext, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::current(); + let mut ctx = context::ServerContext::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1103,7 +1091,7 @@ mod tests { impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::Context, + _: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); @@ -1111,15 +1099,15 @@ mod tests { } } impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { + async fn after(&mut self, _: &mut context::ServerContext, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::ServerContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::current(), 7) + .serve(&mut context::ServerContext::current(), 7) .await?; Ok(()) } @@ -1127,10 +1115,10 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::ServerContext::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1143,14 +1131,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1166,7 +1154,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1174,7 +1162,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1197,7 +1185,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1226,7 +1214,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1268,7 +1256,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1291,7 +1279,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1335,9 +1323,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request - .execute(serve(|_, _| async { Ok(()) }.boxed())) - .await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() @@ -1358,7 +1344,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1388,7 +1374,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1409,7 +1395,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1428,7 +1414,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index eddf3794e..cb01021f5 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -65,7 +65,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::current(); +/// let mut context = context::ClientContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 64b97453a..38b0998bf 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// .before(|_ctx: &mut context::ServerContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// .after(|_ctx: &mut context::ServerContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -134,7 +134,7 @@ pub trait RequestHook: Serve { /// struct PrintLatency(Instant); /// /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } @@ -143,7 +143,7 @@ pub trait RequestHook: Serve { /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut context::ServerContext, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::current(); + /// let mut context = context::ServerContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index e2c49b2f1..d9e676ca4 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,15 +15,15 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result); } impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result) { self(ctx, resp).await } } @@ -59,7 +59,7 @@ where async fn serve( self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, req: Serv::Req, ) -> Result { let ServeThenHook { diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index ad04cc784..4a1b2ad8a 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -19,7 +19,7 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -34,7 +34,7 @@ pub trait BeforeRequestList: BeforeRequest { /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, >( self, @@ -56,10 +56,10 @@ pub trait BeforeRequestList: BeforeRequest { impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } @@ -87,7 +87,7 @@ where async fn serve( self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, req: Self::Req, ) -> Result { let HookThenServe { @@ -121,7 +121,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::current(); +/// let mut context = context::ServerContext::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -141,7 +141,7 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -150,7 +150,7 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { Ok(()) } } @@ -211,7 +211,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = context::current(); + let mut context = context::ServerContext::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index e06f34113..af37427af 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,7 +46,7 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..70c4e7f69 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -92,7 +92,7 @@ impl FakeChannel>, Response> { let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), }, diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index e064e6813..5cb897569 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -209,8 +209,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(&mut context::current(), "123".into()).await; - let response2 = client.call(&mut context::current(), "abc".into()).await; + let response1 = client.call(&mut context::ClientContext::current(), "123".into()).await; + let response2 = client.call(&mut context::ClientContext::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index e051b434e..e4cbf338d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: &mut context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::current(), TestData::White) + .get_opposite_color(&mut context::ClientContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index f3adda2fb..46ce7bd47 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: &mut context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: &mut context::Context, name: String) -> String { + async fn hey(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}.") } } @@ -43,7 +43,7 @@ async fn sequential() { })) .for_each(|response| response), ); - assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::ClientContext::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -57,7 +57,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: &mut context::Context) { + async fn r#loop(self, _: &mut context::ServerContext) { loop { futures::pending!(); } @@ -73,7 +73,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::current(); + let mut ctx = context::ClientContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -114,9 +114,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(&mut context::current(), "Tim".to_string()).await, + client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -147,8 +147,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(&mut context::current(), 1, 2).await; - let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::ClientContext::current(), 1, 2).await; + let res2 = client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -171,7 +171,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::current(); + let mut context = context::ClientContext::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -200,9 +200,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::current(); - let mut context2 = context::current(); - let mut context3 = context::current(); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); + let mut context3 = context::ClientContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -234,8 +234,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::current(); - let mut context2 = context::current(); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -257,7 +257,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: &mut context::Context) -> u32 { + async fn count(self, _: &mut context::ServerContext) -> u32 { self.0 += 1; self.0 } @@ -274,8 +274,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(&mut context::current()).await, Ok(1)); - assert_matches!(client.count(&mut context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(2)); Ok(()) } From d1afa2cbf7d3db88a9550186e3b6ab874af892ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 20:37:23 +0100 Subject: [PATCH 05/26] allow transports to see and manipulate client and server contexts. --- example-service/src/client.rs | 7 +- example-service/src/server.rs | 12 ++-- plugins/Cargo.toml | 1 + plugins/src/lib.rs | 10 ++- tarpc/examples/compression.rs | 17 ++--- tarpc/examples/custom_transport.rs | 7 +- tarpc/examples/pubsub.rs | 14 ++-- tarpc/examples/readme.rs | 11 +-- tarpc/examples/tls_over_tcp.rs | 5 +- tarpc/examples/tracing.rs | 16 +++-- tarpc/src/client.rs | 45 ++++++------ tarpc/src/client/in_flight_requests.rs | 17 +++-- tarpc/src/context.rs | 5 +- tarpc/src/lib.rs | 43 +++++++++--- tarpc/src/server.rs | 94 +++++++++++++++----------- tarpc/src/server/incoming.rs | 7 +- tarpc/src/server/testing.rs | 6 +- tarpc/src/transport/channel.rs | 43 ++++++++---- tarpc/tests/dataservice.rs | 7 +- tarpc/tests/service_functional.rs | 53 +++++++++++---- 20 files changed, 268 insertions(+), 152 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index dc7104bfd..2984ae49c 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -7,7 +7,8 @@ use clap::Parser; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::{client, context, tokio_serde::formats::Json}; +use futures::{future, SinkExt}; +use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; use tarpc::context::ClientContext; @@ -30,9 +31,11 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); + let transport = transport.await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); + let client = WorldClient::new(client::Config::default(), transport).spawn(); let hello = async move { let mut context = ClientContext::current(); diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 0845783c7..00b3eb1fb 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -15,12 +15,9 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::{ - context, - server::{self, Channel, incoming::Incoming}, - tokio_serde::formats::Json, -}; +use tarpc::{context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, ClientMessage}; use tokio::time; +use tarpc::context::{ServerContext, SharedContext}; #[derive(Parser)] struct Flags { @@ -62,13 +59,14 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index 8be746c26..eeab84924 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,5 +30,6 @@ proc-macro = true [dev-dependencies] assert-type-eq = "0.1.0" futures = "0.3" +futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc", features = ["serde1"] } diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 886b85b48..bc52cf849 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -376,6 +376,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// /// ```no_run /// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; +/// use futures_util::{TryStreamExt, sink::SinkExt}; /// /// #[service] /// pub trait Calculator { @@ -394,6 +395,13 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // This could be any transport. /// let (client_side, server_side) = transport::channel::unbounded(); /// +/// let client_side = client_side.with(|msg: tarpc::ClientMessage| async move { +/// Ok(msg.map_context(|ctx| ctx.shared_context)) +/// }); +/// let server_side = server_side.map_ok(|msg: tarpc::ClientMessage| +/// msg.map_context(tarpc::context::ServerContext::new) +/// ); +/// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); /// @@ -738,7 +746,7 @@ impl ServiceGenerator<'_> { ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<#response_ident>> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 663236731..c8c13d1db 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,12 +9,8 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::{ - client, context, - serde_transport::tcp, - server::{BaseChannel, Channel}, - tokio_serde::formats::Bincode, -}; +use tarpc::{client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -120,17 +116,22 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; + let addr = incoming.local_addr(); tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); - BaseChannel::with_defaults(add_compression(transport)) + let transport = add_compression(transport); + let transport = transport.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) .await; }); let transport = tcp::connect(addr, Bincode::default).await?; - let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); + let transport = add_compression(transport); + let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 1c682173d..6abf78a58 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -5,8 +5,8 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext}; -use tarpc::serde_transport as transport; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{serde_transport as transport, ClientMessage}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -23,7 +23,6 @@ struct Service; impl PingService for Service { async fn ping(self, _: &mut ServerContext) {} } - #[tokio::main] async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; @@ -40,6 +39,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); + let transport = transport.map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -50,6 +50,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); + let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 83c1371b9..bf95a2e15 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,15 +48,11 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::{ - client, context, - serde_transport::tcp, - server::{self, Channel}, - tokio_serde::formats::Json, -}; +use tarpc::{client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, ClientMessage}; use tokio::net::ToSocketAddrs; use tracing::info; use tracing_subscriber::prelude::*; +use tarpc::context::{ServerContext, SharedContext}; pub mod subscriber { #[tarpc::service] @@ -104,6 +100,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; + let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -164,6 +161,8 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); + let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + server::BaseChannel::with_defaults(publisher) .execute(self.serve()) .for_each(spawn) @@ -183,6 +182,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); + let conn = conn.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let tarpc::client::NewClient { client: subscriber, @@ -341,7 +341,7 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?, + tcp::connect(addrs.publisher, Json::default).await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))) ) .spawn(); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 60daf4e45..884e298f3 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,10 +5,8 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::{ - client, context, - server::{self, Channel}, -}; +use tarpc::{client, context, server::{self, Channel}, transport, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -34,7 +32,10 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); + let (client_transport, server_transport) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index cc3c1690b..e7307b98d 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -17,8 +17,7 @@ use tokio_rustls::rustls::{ server::{WebPkiClientVerifier, danger::ClientCertVerifier}, }; use tokio_rustls::{TlsAcceptor, TlsConnector}; - -use tarpc::context::{ClientContext, ServerContext}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -115,6 +114,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); + let transport = transport.map_ok(|c: tarpc::ClientMessage| c.map_context(|ctx| ServerContext::new(ctx))); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -144,6 +144,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index be1b539c1..66a92738d 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -35,6 +35,7 @@ use tarpc::{ }; use tokio::net::TcpStream; use tracing_subscriber::prelude::*; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; pub mod add { #[tarpc::service] @@ -124,7 +125,7 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -173,23 +174,28 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); + let map_context = |msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context)); + let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, - tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?.with(map_context), + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?.with(map_context), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? - .filter_map(|r| future::ready(r.ok())); - let addr = double_listener.get_ref().local_addr(); + .filter_map(|r| future::ready(r.ok())) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))); + let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; + let to_double_server = to_double_server.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 8d3b9f4a7..f2cf73e24 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -31,6 +31,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; +use crate::context::ClientContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -128,7 +129,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::SharedContext, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +154,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: ctx.clone(), + ctx: ctx.shared_context.clone(), span, request_id, request, @@ -239,7 +240,7 @@ pub fn new( transport: C, ) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -287,7 +288,7 @@ pub struct RequestDispatch { impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -308,7 +309,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -457,7 +458,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -510,18 +511,20 @@ where // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. + + let trace_context = ctx.trace_context; + let deadline = ctx.deadline; + + let client_context = context::ClientContext::new(ctx); + let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ctx.clone(), + context: client_context, }); - //TODO: Feels like we could avoid either saving the request context in insert_request - // or submitting the context in start_request. - let full_context = context::ClientContext::new(ctx); - self.in_flight_requests() - .insert_request(request_id, full_context, span.clone(), response_completion) + .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -543,14 +546,14 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { + let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { Some(triple) => triple, None => return Poll::Ready(None), }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + trace_context, request_id, }; self.start_send(cancel) @@ -640,7 +643,7 @@ where impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { type Output = Result<(), ChannelError>; @@ -710,13 +713,15 @@ mod tests { #[tokio::test] async fn response_completes_request_future() { - let (mut dispatch, mut _channel, mut server_channel) = set_up(); + let (mut dispatch, _channel, mut server_channel) = set_up(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let context = ClientContext::current(); + dispatch .in_flight_requests - .insert_request(0, ClientContext::current(), Span::current(), tx) + .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -1052,12 +1057,12 @@ mod tests { RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, >, >, >, Channel, - UnboundedChannel, Response>, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1135,7 +1140,7 @@ mod tests { } async fn send_response( - channel: &mut UnboundedChannel, Response>, + channel: &mut UnboundedChannel, Response>, response: Response, ) { channel.send(response).await.unwrap(); diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index a368a5a48..0ffb50c63 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,15 +1,13 @@ -use crate::{ - context, - util::{Compact, TimeUntil}, -}; +use crate::{trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::{ collections::hash_map, task::{Context, Poll}, }; +use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; -use tracing::Span; +use tracing::{Span}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -29,7 +27,7 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::ClientContext, + ctx: trace::Context, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. @@ -56,13 +54,14 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::ClientContext, + ctx: trace::Context, + deadline: Instant, span: Span, response_completion: oneshot::Sender, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { - let timeout = ctx.deadline.time_until(); + let timeout = deadline.time_until(); let deadline_key = self.deadlines.insert(request_id, timeout); vacant.insert(RequestData { ctx, @@ -106,7 +105,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::ClientContext, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(trace::Context, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index a96c49095..e72ab130f 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,10 +10,7 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::{ - convert::TryFrom, - time::{Duration, Instant}, -}; +use std::{convert::TryFrom, time::{Duration, Instant}}; use std::ops::{Deref, DerefMut}; use tracing_opentelemetry::OpenTelemetrySpanExt; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index a83efae02..c097372bc 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -142,7 +142,10 @@ //! # prelude::*, //! # }; //! # use tarpc::{ +//! # ClientMessage, //! # client, context, +//! # context::{ClientContext, ServerContext, SharedContext}, +//! # transport::channel, //! # server::{self, Channel}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. @@ -167,7 +170,10 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); +//! let (client_transport, server_transport) = channel::unbounded_mapped( +//! |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), +//! |msg: ClientMessage| msg.map_context(ServerContext::new), +//! ); //! //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( @@ -198,7 +204,7 @@ //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![deny(missing_docs)] +#![deny(missing_docs, warnings, unused, dead_code)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -252,16 +258,18 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::ops::Deref; +use crate::context::{SharedContext}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub enum ClientMessage { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. - Request(Request), + Request(Request), /// A command to cancel an in-flight request, automatically sent by the client when a response /// future is dropped. /// @@ -279,16 +287,35 @@ pub enum ClientMessage { }, } +impl ClientMessage { + /// Creates a new ClientMessage by mapping the context using the provided function. + pub fn map_context(self, f: F) -> ClientMessage where F: FnOnce(Ctx) -> Ctx2 { + match self { + ClientMessage::Request(Request { context, id, message }) => { + ClientMessage::Request(Request { + context: f(context), + id, + message, + }) + } + ClientMessage::Cancel { trace_context, request_id } => { + ClientMessage::Cancel { trace_context, request_id } + } + } + } +} + + /// A request from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Request { +pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::SharedContext, + pub context: Ctx, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. - pub message: T, + pub message: Req, } /// Implemented by the request types generated by tarpc::service. @@ -491,7 +518,7 @@ impl ServerError { } } -impl Request { +impl Request where Ctx: Deref { /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { &self.context.deadline diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 3b01d207d..34efc1be6 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -60,7 +60,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -154,7 +154,7 @@ pub struct BaseChannel { impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -200,7 +200,7 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, + mut request: Request, ) -> Result, AlreadyExistsError> { let span = info_span!( "RPC", @@ -256,7 +256,7 @@ impl fmt::Debug for BaseChannel { #[derive(Debug)] pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -341,7 +341,9 @@ where /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{ClientContext, SharedContext, ServerContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -350,7 +352,10 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -385,7 +390,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext, ServerContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -394,7 +399,10 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -420,7 +428,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -527,7 +535,7 @@ where impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, { type Error = ChannelError; @@ -580,7 +588,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -736,7 +744,8 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; + /// use tarpc::context::{ClientContext, SharedContext, ServerContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -744,7 +753,11 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); + /// /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -807,7 +820,7 @@ impl Drop for ResponseGuard { /// be sent to the Channel to clean up associated request state. #[derive(Debug)] pub struct InFlightRequest { - request: Request, + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, @@ -816,7 +829,7 @@ pub struct InFlightRequest { impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -839,7 +852,9 @@ impl InFlightRequest { /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{ClientContext, SharedContext, ServerContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -848,7 +863,10 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded(); + /// let (tx, rx) = transport::channel::unbounded_mapped( + /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + /// |msg: ClientMessage| msg.map_context(ServerContext::new), + /// ); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -876,7 +894,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -885,7 +903,7 @@ impl InFlightRequest { let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { - let message = serve.serve(&mut full_context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, @@ -979,11 +997,11 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use tracing_subscriber::filter::FilterExt; + use crate::context::ServerContext; fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, + Pin, Response>>>>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -993,11 +1011,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel, Response>>, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1012,11 +1030,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel, Response>>, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1026,9 +1044,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::SharedContext::current(), + context: context::ServerContext::current(), id: 0, message: req, }) @@ -1131,14 +1149,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: () }), Err(AlreadyExistsError) @@ -1154,7 +1172,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1162,7 +1180,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1185,7 +1203,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1214,7 +1232,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1256,7 +1274,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1279,7 +1297,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1344,7 +1362,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1374,7 +1392,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1395,7 +1413,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); @@ -1414,7 +1432,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::ServerContext::current(), message: (), }) .unwrap(); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index cb01021f5..ad91f0c19 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -48,7 +48,9 @@ where /// # Example /// ```rust /// use tarpc::{ +/// ClientMessage, /// context, +/// context::{ClientContext, ServerContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -57,7 +59,10 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// let (tx, rx) = transport::channel::unbounded_mapped( +/// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), +/// |msg: ClientMessage| msg.map_context(ServerContext::new), +/// ); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 70c4e7f69..ac2201933 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -13,7 +13,7 @@ use crate::{ use futures::{Sink, Stream, task::*}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::Instant}; -use tracing::Span; +use tracing::{Span}; #[pin_project] pub(crate) struct FakeChannel { @@ -92,10 +92,10 @@ impl FakeChannel>, Response> { let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::SharedContext { + context: context::ServerContext::new(context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), - }, + }), id, message, }, diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 5cb897569..a319ef046 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,10 +6,11 @@ //! Transports backed by in-memory channels. -use futures::{Sink, Stream, task::*}; +use futures::{Sink, Stream, task::*, SinkExt, TryStreamExt}; use pin_project::pin_project; -use std::{error::Error, pin::Pin}; +use std::{error::Error, future, pin::Pin}; use tokio::sync::mpsc; +use crate::Transport; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] @@ -39,6 +40,23 @@ pub fn unbounded() -> ( ) } +/// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's +/// [`Sink`]. +pub fn unbounded_mapped(mut f: F, mut g: G) -> ( + impl Transport, + impl Transport, +) where + F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, + G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, +{ + let (client, server) = unbounded(); + + let client = client.with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))); + let server = server.map_ok(move |msg: SerializedSinkItem| g(msg)); + + (client, server) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -161,20 +179,15 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::{ - ServerError, - client::{self, RpcError}, - context, - server::{BaseChannel, incoming::Incoming, serve}, - transport::{ - self, - channel::{Channel, UnboundedChannel}, - }, - }; + use crate::{ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, transport::{ + self, + channel::{Channel, UnboundedChannel}, + }, ClientMessage}; use assert_matches::assert_matches; use futures::{prelude::*, stream}; use std::io; use tracing::trace; + use crate::context::{ClientContext, ServerContext, SharedContext}; #[test] fn ensure_is_transport() { @@ -187,7 +200,11 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = transport::channel::unbounded(); + let (client_channel, server_channel) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index e4cbf338d..73f6656d9 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,10 +1,11 @@ use futures::prelude::*; -use tarpc::serde_transport; +use tarpc::{serde_transport, ClientMessage}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, }; use tokio_serde::formats::Json; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc::derive_serde] #[derive(Debug, PartialEq, Eq)] @@ -43,13 +44,15 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = serde_transport::tcp::connect(addr, Json::default).await?.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 46ce7bd47..30e4c0743 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,13 +4,9 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::{ - client::{self}, - context, - server::{BaseChannel, Channel, incoming::Incoming}, - transport::channel, -}; +use tarpc::{client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, transport, transport::channel, ClientMessage}; use tokio::join; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc_plugins::service] trait Service { @@ -33,7 +29,11 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = tarpc::transport::channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( @@ -66,7 +66,11 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -105,6 +109,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -112,6 +117,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); @@ -137,6 +143,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) + .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -144,6 +151,7 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; + let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail @@ -160,7 +168,11 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -169,6 +181,7 @@ async fn concurrent() -> anyhow::Result<()> { .for_each(spawn), ); + let client = ServiceClient::new(client::Config::default(), tx).spawn(); let mut context = context::ClientContext::current(); @@ -189,7 +202,11 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -225,7 +242,11 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -263,14 +284,18 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded(); - tokio::spawn(async { + let (tx, rx) = channel::unbounded_mapped( + |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), + |msg: ClientMessage| msg.map_context(ServerContext::new), + ); + + tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); while let Some(Ok(request)) = requests.next().await { request.execute(counter.serve()).await; - } + }; }); let client = CounterClient::new(client::Config::default(), tx).spawn(); From 97d0a37bafa5788531820b653c679930420a7092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 21:07:22 +0100 Subject: [PATCH 06/26] run cargo fmt --- example-service/src/client.rs | 10 ++- example-service/src/server.rs | 14 +++- plugins/tests/service.rs | 7 +- tarpc/examples/compression.rs | 19 ++++- tarpc/examples/custom_transport.rs | 9 ++- tarpc/examples/pubsub.rs | 59 ++++++++++---- tarpc/examples/readme.rs | 10 ++- tarpc/examples/tls_over_tcp.rs | 18 +++-- tarpc/examples/tracing.rs | 40 +++++++--- tarpc/src/client.rs | 46 ++++++++--- tarpc/src/client/in_flight_requests.rs | 9 ++- tarpc/src/client/stub.rs | 25 ++++-- tarpc/src/client/stub/load_balance.rs | 12 ++- tarpc/src/client/stub/mock.rs | 6 +- tarpc/src/context.rs | 14 ++-- tarpc/src/lib.rs | 41 ++++++---- tarpc/src/server.rs | 72 ++++++++++++++---- tarpc/src/server/request_hook/after.rs | 12 ++- tarpc/src/server/request_hook/before.rs | 18 ++++- .../server/request_hook/before_and_after.rs | 6 +- tarpc/src/server/testing.rs | 2 +- tarpc/src/transport/channel.rs | 36 ++++++--- tarpc/tests/dataservice.rs | 16 +++- tarpc/tests/service_functional.rs | 76 +++++++++++++++---- 24 files changed, 432 insertions(+), 145 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 2984ae49c..e425c9eb2 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -5,13 +5,13 @@ // https://opensource.org/licenses/MIT. use clap::Parser; +use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use futures::{future, SinkExt}; +use tarpc::context::ClientContext; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; -use tarpc::context::ClientContext; #[derive(Parser)] struct Flags { @@ -31,7 +31,11 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport.await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport + .await? + .with(|msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 00b3eb1fb..7e29da291 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -15,9 +15,13 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::{context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, ClientMessage}; -use tokio::time; use tarpc::context::{ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, context, + server::{self, Channel, incoming::Incoming}, + tokio_serde::formats::Json, +}; +use tokio::time; #[derive(Parser)] struct Flags { @@ -59,7 +63,11 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index b03f3470f..756766621 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,7 +12,12 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: &mut context::ServerContext, s: String, i: i32) -> (String, i32) { + async fn two_part( + self, + _: &mut context::ServerContext, + s: String, + i: i32, + ) -> (String, i32) { (s, i) } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index c8c13d1db..46300999e 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,8 +9,13 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::{client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, ClientMessage}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + serde_transport::tcp, + server::{BaseChannel, Channel}, + tokio_serde::formats::Bincode, +}; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -121,7 +126,9 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let transport = transport.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -130,12 +137,16 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", - client.hello(&mut context::ClientContext::current(), "friend".into()).await? + client + .hello(&mut context::ClientContext::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 6abf78a58..a1b1e4410 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,10 +6,10 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::{serde_transport as transport, ClientMessage}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::{ClientMessage, serde_transport as transport}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -39,7 +39,8 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); + let transport = transport + .map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -50,7 +51,9 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index bf95a2e15..16195ef3f 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,11 +48,16 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::{client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + serde_transport::tcp, + server::{self, Channel}, + tokio_serde::formats::Json, +}; use tokio::net::ToSocketAddrs; use tracing::info; use tracing_subscriber::prelude::*; -use tarpc::context::{ServerContext, SharedContext}; pub mod subscriber { #[tarpc::service] @@ -100,7 +105,9 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let publisher = publisher.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -161,7 +168,9 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx))); + let publisher = publisher.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -182,7 +191,11 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let conn = conn.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let tarpc::client::NewClient { client: subscriber, @@ -210,7 +223,10 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(&mut context::ClientContext::current()).await { + if let Ok(topics) = subscriber + .topics(&mut context::ClientContext::current()) + .await + { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -271,10 +287,15 @@ impl publisher::Publisher for Publisher { }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { - client.receive(&mut context::ClientContext::current(), topic.clone(), message.clone()).await + client + .receive( + &mut context::ClientContext::current(), + topic.clone(), + message.clone(), + ) + .await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -341,31 +362,43 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))) + tcp::connect(addrs.publisher, Json::default).await?.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ), ) .spawn(); publisher - .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) + .publish( + &mut ClientContext::current(), + "calculus".into(), + "sqrt(2)".into(), + ) .await?; publisher .publish( - &mut context::current(), + &mut ClientContext::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut ClientContext::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher .publish( - &mut context::current(), + &mut ClientContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 884e298f3..44ae497ff 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,8 +5,12 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::{client, context, server::{self, Channel}, transport, ClientMessage}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, client, context, + server::{self, Channel}, + transport, +}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -47,7 +51,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(&mut context::ClientContext::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::ClientContext::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index e7307b98d..bac4a8048 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,6 +10,11 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::serde_transport as transport; +use tarpc::server::{BaseChannel, Channel}; +use tarpc::tokio_serde::formats::Bincode; +use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -17,11 +22,6 @@ use tokio_rustls::rustls::{ server::{WebPkiClientVerifier, danger::ClientCertVerifier}, }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::serde_transport as transport; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; #[tarpc::service] pub trait PingService { @@ -114,7 +114,9 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: tarpc::ClientMessage| c.map_context(|ctx| ServerContext::new(ctx))); + let transport = transport.map_ok(|c: tarpc::ClientMessage| { + c.map_context(|ctx| ServerContext::new(ctx)) + }); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -144,7 +146,9 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with(|msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 66a92738d..52b068bc8 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,6 +19,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -35,7 +36,6 @@ use tarpc::{ }; use tokio::net::TcpStream; use tracing_subscriber::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; pub mod add { #[tarpc::service] @@ -125,7 +125,8 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; + N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -174,33 +175,54 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = |msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context)); + let map_context = |msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }; let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default).await?.with(map_context), - tarpc::serde_transport::tcp::connect(addr2, Json::default).await?.with(map_context), + tarpc::serde_transport::tcp::connect(addr1, Json::default) + .await? + .with(map_context), + tarpc::serde_transport::tcp::connect(addr2, Json::default) + .await? + .with(map_context), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))); + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = to_double_server.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let to_double_server = to_double_server.with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(&mut context::ClientContext::current(), 1).await?); + tracing::info!( + "{:?}", + double_client + .double(&mut context::ClientContext::current(), 1) + .await? + ); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index f2cf73e24..ebcb69db1 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,6 +9,7 @@ mod in_flight_requests; pub mod stub; +use crate::context::ClientContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -31,7 +32,6 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; -use crate::context::ClientContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -129,7 +129,11 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { + pub async fn call( + &self, + ctx: &mut context::ClientContext, + request: Req, + ) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -309,7 +313,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send( + self: &mut Pin<&mut Self>, + message: ClientMessage, + ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -524,7 +531,13 @@ where }); self.in_flight_requests() - .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) + .insert_request( + request_id, + trace_context, + deadline, + span.clone(), + response_completion, + ) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -546,10 +559,11 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = + match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { @@ -674,7 +688,7 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::Context, + pub ctx: context::SharedContext, pub span: Span, pub request_id: u64, pub request: Req, @@ -686,10 +700,10 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; + use crate::context::{ClientContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; @@ -721,7 +735,13 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) + .insert_request( + 0, + context.trace_context, + context.deadline, + Span::current(), + tx, + ) .unwrap(); server_channel .send(Response { @@ -888,7 +908,9 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(&mut ClientContext::current(), "hi".to_string()).await; + let resp = channel + .call(&mut ClientContext::current(), "hi".to_string()) + .await; assert_matches!(resp, Err(RpcError::Shutdown)); } diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 0ffb50c63..7a554de27 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,13 +1,16 @@ -use crate::{trace, util::{Compact, TimeUntil}}; +use crate::{ + trace, + util::{Compact, TimeUntil}, +}; use fnv::FnvHashMap; +use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, }; -use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; -use tracing::{Span}; +use tracing::Span; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index c7dc12008..b99f8e42c 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,8 +24,11 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: &mut context::ClientContext, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut context::ClientContext, + request: Self::Req, + ) -> Result; } impl Stub for Channel @@ -35,7 +38,11 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { + async fn call( + &self, + ctx: &mut context::ClientContext, + request: Req, + ) -> Result { Self::call(self, ctx, request).await } } @@ -46,10 +53,18 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: &mut context::ClientContext, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut context::ClientContext, + req: Self::Req, + ) -> Result { let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); - let res = self.clone().serve(&mut server_ctx, req).await.map_err(RpcError::Server); + let res = self + .clone() + .serve(&mut server_ctx, req) + .await + .map_err(RpcError::Server); ctx.shared_context = server_ctx.shared_context; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index bf70ebe2a..62c8bf677 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -200,13 +200,19 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(&mut context::ClientContext::current(), 'a').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'a') + .await?; assert_eq!(resp, 1); - let resp = stub.call(&mut context::ClientContext::current(), 'b').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'b') + .await?; assert_eq!(resp, 2); - let resp = stub.call(&mut context::ClientContext::current(), 'c').await?; + let resp = stub + .call(&mut context::ClientContext::current(), 'c') + .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 451544433..bebd8fc99 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,11 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: &mut context::ClientContext, request: Self::Req) -> Result { + async fn call( + &self, + _: &mut context::ClientContext, + request: Self::Req, + ) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e72ab130f..bbbc3721d 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,8 +10,11 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::{convert::TryFrom, time::{Duration, Instant}}; use std::ops::{Deref, DerefMut}; +use std::{ + convert::TryFrom, + time::{Duration, Instant}, +}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -51,9 +54,7 @@ pub struct ServerContext { impl ServerContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - } + Self { shared_context } } /// Creates a new ServerContext for the current shared context with no extensions. @@ -85,15 +86,12 @@ impl DerefMut for ServerContext { pub struct ClientContext { /// Shared context sent from client to server which contains information used by both sides. pub shared_context: SharedContext, - } impl ClientContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - } + Self { shared_context } } /// Creates a new ServerContext for the current shared context with no extensions. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index c097372bc..cf3423eb5 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -257,9 +257,9 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use crate::context::SharedContext; use std::ops::Deref; -use crate::context::{SharedContext}; +use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; /// A message from a client to a server. #[derive(Debug)] @@ -289,23 +289,31 @@ pub enum ClientMessage { impl ClientMessage { /// Creates a new ClientMessage by mapping the context using the provided function. - pub fn map_context(self, f: F) -> ClientMessage where F: FnOnce(Ctx) -> Ctx2 { + pub fn map_context(self, f: F) -> ClientMessage + where + F: FnOnce(Ctx) -> Ctx2, + { match self { - ClientMessage::Request(Request { context, id, message }) => { - ClientMessage::Request(Request { - context: f(context), - id, - message, - }) - } - ClientMessage::Cancel { trace_context, request_id } => { - ClientMessage::Cancel { trace_context, request_id } - } + ClientMessage::Request(Request { + context, + id, + message, + }) => ClientMessage::Request(Request { + context: f(context), + id, + message, + }), + ClientMessage::Cancel { + trace_context, + request_id, + } => ClientMessage::Cancel { + trace_context, + request_id, + }, } } } - /// A request from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] @@ -518,7 +526,10 @@ impl ServerError { } } -impl Request where Ctx: Deref { +impl Request +where + Ctx: Deref, +{ /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { &self.context.deadline diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 34efc1be6..559d80ba8 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,6 +6,7 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::ServerContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -76,7 +77,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut context::ServerContext, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut context::ServerContext, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -104,7 +109,10 @@ impl Copy for ServeFn where F: Copy {} /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::ServerContext, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -115,7 +123,10 @@ where impl Serve for ServeFn where Req: RequestName, - for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut context::ServerContext, + Req, + ) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; @@ -900,7 +911,6 @@ impl InFlightRequest { }, } = self; span.record("otel.name", message.name()); - let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { let message = serve.serve(&mut context, message).await; @@ -980,6 +990,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::ServerContext; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -997,10 +1008,17 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use crate::context::ServerContext; fn test_channel() -> ( - Pin, Response>>>>, + Pin< + Box< + BaseChannel< + Req, + Resp, + UnboundedChannel, Response>, + >, + >, + >, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1011,7 +1029,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + UnboundedChannel, Response>, + >, >, >, >, @@ -1030,7 +1052,11 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + channel::Channel, Response>, + >, >, >, >, @@ -1061,7 +1087,10 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!(serve.serve(&mut context::ServerContext::current(), 7).await, Ok(7)); + assert_matches!( + serve.serve(&mut context::ServerContext::current(), 7).await, + Ok(7) + ); } #[tokio::test] @@ -1081,10 +1110,13 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::ServerContext, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }.boxed()); + let serve = serve(move |ctx: &mut context::ServerContext, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() + }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::ServerContext::current(); ctx.deadline = some_other_time; @@ -1117,7 +1149,11 @@ mod tests { } } impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::ServerContext, _: &mut Result) { + async fn after( + &mut self, + _: &mut context::ServerContext, + _: &mut Result, + ) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } @@ -1136,7 +1172,9 @@ mod tests { let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(&mut context::ServerContext::current(), 7).await; + let resp: Result = deadline_hook + .serve(&mut context::ServerContext::current(), 7) + .await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1341,7 +1379,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index d9e676ca4..64d65807f 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,7 +15,11 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result); + async fn after( + &mut self, + ctx: &mut context::ServerContext, + resp: &mut Result, + ); } impl AfterRequest for F @@ -23,7 +27,11 @@ where F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result) { + async fn after( + &mut self, + ctx: &mut context::ServerContext, + resp: &mut Result, + ) { self(ctx, resp).await } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 4a1b2ad8a..1f647227f 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -19,7 +19,11 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError>; + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -59,7 +63,11 @@ where F: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError> { self(ctx, req).await } } @@ -141,7 +149,11 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut context::ServerContext, + req: &Req, + ) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index af37427af..dff0abe0b 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,7 +46,11 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { + async fn serve( + self, + ctx: &mut context::ServerContext, + req: Req, + ) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index ac2201933..709167751 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -13,7 +13,7 @@ use crate::{ use futures::{Sink, Stream, task::*}; use pin_project::pin_project; use std::{collections::VecDeque, io, pin::Pin, time::Instant}; -use tracing::{Span}; +use tracing::Span; #[pin_project] pub(crate) struct FakeChannel { diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index a319ef046..9607b5ef0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,11 +6,11 @@ //! Transports backed by in-memory channels. -use futures::{Sink, Stream, task::*, SinkExt, TryStreamExt}; +use crate::Transport; +use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; use tokio::sync::mpsc; -use crate::Transport; /// Errors that occur in the sending or receiving of messages over a channel. #[derive(thiserror::Error, Debug)] @@ -42,10 +42,14 @@ pub fn unbounded() -> ( /// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. -pub fn unbounded_mapped(mut f: F, mut g: G) -> ( +pub fn unbounded_mapped( + mut f: F, + mut g: G, +) -> ( impl Transport, impl Transport, -) where +) +where F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, { @@ -179,15 +183,21 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::{ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, transport::{ - self, - channel::{Channel, UnboundedChannel}, - }, ClientMessage}; + use crate::context::{ClientContext, ServerContext, SharedContext}; + use crate::{ + ClientMessage, ServerError, + client::{self, RpcError}, + context, + server::{BaseChannel, incoming::Incoming, serve}, + transport::{ + self, + channel::{Channel, UnboundedChannel}, + }, + }; use assert_matches::assert_matches; use futures::{prelude::*, stream}; use std::io; use tracing::trace; - use crate::context::{ClientContext, ServerContext, SharedContext}; #[test] fn ensure_is_transport() { @@ -226,8 +236,12 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(&mut context::ClientContext::current(), "123".into()).await; - let response2 = client.call(&mut context::ClientContext::current(), "abc".into()).await; + let response1 = client + .call(&mut context::ClientContext::current(), "123".into()) + .await; + let response2 = client + .call(&mut context::ClientContext::current(), "abc".into()) + .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 73f6656d9..0ee12183d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,11 +1,11 @@ use futures::prelude::*; -use tarpc::{serde_transport, ClientMessage}; +use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, }; use tokio_serde::formats::Json; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; #[tarpc::derive_serde] #[derive(Debug, PartialEq, Eq)] @@ -44,14 +44,22 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: ClientMessage| msg.map_context(|ctx| ServerContext::new(ctx)))) + .map(|t| { + t.map_ok(|msg: ClientMessage| { + msg.map_context(|ctx| ServerContext::new(ctx)) + }) + }) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default).await?.with(|msg: ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = serde_transport::tcp::connect(addr, Json::default) + .await? + .with(|msg: ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 30e4c0743..b6ba72026 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,9 +4,16 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::{client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, transport, transport::channel, ClientMessage}; -use tokio::join; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::{ + ClientMessage, + client::{self}, + context, + server::{BaseChannel, Channel, incoming::Incoming}, + transport, + transport::channel, +}; +use tokio::join; #[tarpc_plugins::service] trait Service { @@ -43,7 +50,13 @@ async fn sequential() { })) .for_each(|response| response), ); - assert_eq!(client.call(&mut context::ClientContext::current(), 1).await.unwrap(), 2); + assert_eq!( + client + .call(&mut context::ClientContext::current(), 1) + .await + .unwrap(), + 2 + ); } #[tokio::test] @@ -71,7 +84,6 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { |msg: ClientMessage| msg.map_context(ServerContext::new), ); - // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { @@ -109,7 +121,13 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) + .map(|t| { + t.map_ok( + |msg: tarpc::ClientMessage| { + msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) + }, + ) + }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -117,10 +135,19 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); + assert_matches!( + client + .add(&mut context::ClientContext::current(), 1, 2) + .await, + Ok(3) + ); assert_matches!( client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." @@ -143,7 +170,13 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(|msg: tarpc::ClientMessage| msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)))) + .map(|t| { + t.map_ok( + |msg: tarpc::ClientMessage| { + msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) + }, + ) + }) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -151,12 +184,20 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with(|msg: tarpc::ClientMessage| future::ok(msg.map_context(|ctx| ctx.shared_context))); + let transport = transport.with( + |msg: tarpc::ClientMessage| { + future::ok(msg.map_context(|ctx| ctx.shared_context)) + }, + ); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(&mut context::ClientContext::current(), 1, 2).await; - let res2 = client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await; + let res1 = client + .add(&mut context::ClientContext::current(), 1, 2) + .await; + let res2 = client + .hey(&mut context::ClientContext::current(), "Tim".to_string()) + .await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -181,7 +222,6 @@ async fn concurrent() -> anyhow::Result<()> { .for_each(spawn), ); - let client = ServiceClient::new(client::Config::default(), tx).spawn(); let mut context = context::ClientContext::current(); @@ -295,12 +335,18 @@ async fn counter() -> anyhow::Result<()> { while let Some(Ok(request)) = requests.next().await { request.execute(counter.serve()).await; - }; + } }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(1)); - assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(2)); + assert_matches!( + client.count(&mut context::ClientContext::current()).await, + Ok(1) + ); + assert_matches!( + client.count(&mut context::ClientContext::current()).await, + Ok(2) + ); Ok(()) } From 15b84e4f14ddcafffe3a375b428a9a528a85076b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Sun, 23 Nov 2025 21:07:59 +0100 Subject: [PATCH 07/26] run cargo clippy --- example-service/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 7e29da291..c1cf618b9 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -65,7 +65,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())) .map(|t| { t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) + msg.map_context(ServerContext::new) }) }) .map(server::BaseChannel::with_defaults) From 117ae5713324a8b2887c80388294e2cd76df17d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Mon, 24 Nov 2025 17:54:56 +0100 Subject: [PATCH 08/26] simplify api --- example-service/src/client.rs | 5 +-- example-service/src/server.rs | 7 +--- plugins/src/lib.rs | 9 +---- tarpc/examples/compression.rs | 9 ++--- tarpc/examples/custom_transport.rs | 8 ++-- tarpc/examples/pubsub.rs | 24 ++++-------- tarpc/examples/readme.rs | 6 +-- tarpc/examples/tls_over_tcp.rs | 9 ++--- tarpc/examples/tracing.rs | 23 ++++-------- tarpc/src/lib.rs | 6 +-- tarpc/src/server.rs | 21 ++--------- tarpc/src/server/incoming.rs | 5 +-- tarpc/src/transport/channel.rs | 39 ++++++++++++++----- tarpc/tests/dataservice.rs | 11 ++---- tarpc/tests/service_functional.rs | 60 ++++++------------------------ 15 files changed, 81 insertions(+), 161 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e425c9eb2..e2c327dfb 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,6 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; +use tarpc::transport::channel::map_client_context_to_shared; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -33,9 +34,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport .await? - .with(|msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + .with(|msg| future::ok(map_client_context_to_shared(msg))); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index c1cf618b9..7871c5f86 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -22,6 +22,7 @@ use tarpc::{ tokio_serde::formats::Json, }; use tokio::time; +use tarpc::transport::channel::map_shared_context_to_server; #[derive(Parser)] struct Flags { @@ -63,11 +64,7 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(ServerContext::new) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index bc52cf849..21e3cd35b 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -393,14 +393,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// let resp = CalculatorResponse::Add(12); /// /// // This could be any transport. -/// let (client_side, server_side) = transport::channel::unbounded(); -/// -/// let client_side = client_side.with(|msg: tarpc::ClientMessage| async move { -/// Ok(msg.map_context(|ctx| ctx.shared_context)) -/// }); -/// let server_side = server_side.map_ok(|msg: tarpc::ClientMessage| -/// msg.map_context(tarpc::context::ServerContext::new) -/// ); +/// let (client_side, server_side) = transport::channel::unbounded_for_client_server_context(); /// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 46300999e..4e09d625e 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -126,9 +127,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let transport = transport.map_ok(map_shared_context_to_server); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -137,9 +136,7 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index a1b1e4410..350828743 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -9,6 +9,7 @@ use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ClientMessage, serde_transport as transport}; use tokio::net::{UnixListener, UnixStream}; @@ -39,8 +40,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport - .map_ok(|c: ClientMessage| c.map_context(ServerContext::new)); + let transport = transport.map_ok(map_shared_context_to_server); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -51,9 +51,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 16195ef3f..6a6ce5723 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -49,6 +49,7 @@ use std::{ }; use subscriber::Subscriber as _; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -105,9 +106,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let publisher = publisher.map_ok(map_shared_context_to_server); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -168,9 +167,7 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }); + let publisher = publisher.map_ok(map_shared_context_to_server); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -191,12 +188,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); - + let conn = conn.with(|msg| future::ok(map_client_context_to_shared(msg))); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -362,11 +354,9 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default).await?.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ), + tcp::connect(addrs.publisher, Json::default) + .await? + .with(|msg| future::ok(map_client_context_to_shared(msg))), ) .spawn(); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 44ae497ff..b20d4ab91 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -36,10 +36,8 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (client_transport, server_transport) = + transport::channel::unbounded_for_client_server_context(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index bac4a8048..56583b05c 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -15,6 +15,7 @@ use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -114,9 +115,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(|c: tarpc::ClientMessage| { - c.map_context(|ctx| ServerContext::new(ctx)) - }); + let transport = transport.map_ok(map_shared_context_to_server); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -146,9 +145,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 52b068bc8..acf631be8 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,6 +20,7 @@ use std::{ }, }; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -175,17 +176,12 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = |msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }; + let map_context = + |msg: ClientMessage| future::ok(map_client_context_to_shared(msg)); let add_client = add::AddClient::from(make_stub([ tarpc::serde_transport::tcp::connect(addr1, Json::default) @@ -199,20 +195,15 @@ async fn main() -> anyhow::Result<()> { let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }); + .map(|t| t.map_ok(map_shared_context_to_server)); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = to_double_server.with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + let to_double_server = + to_double_server.with(|msg| future::ok(map_client_context_to_shared(msg))); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index cf3423eb5..6f7e08fc0 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -170,11 +170,7 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = channel::unbounded_mapped( -//! |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), -//! |msg: ClientMessage| msg.map_context(ServerContext::new), -//! ); -//! +//! let (client_transport, server_transport) = channel::unbounded_for_client_server_context(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 559d80ba8..87ac89681 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,10 +363,7 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -410,10 +407,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -764,11 +758,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); - /// + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -874,10 +864,7 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_mapped( - /// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - /// |msg: ClientMessage| msg.map_context(ServerContext::new), - /// ); + /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index ad91f0c19..1868cbe47 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -59,10 +59,7 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded_mapped( -/// |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), -/// |msg: ClientMessage| msg.map_context(ServerContext::new), -/// ); +/// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 9607b5ef0..65e987d02 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,7 +6,8 @@ //! Transports backed by in-memory channels. -use crate::Transport; +use crate::context::{ClientContext, ServerContext, SharedContext}; +use crate::{ClientMessage, Transport}; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; @@ -50,8 +51,8 @@ pub fn unbounded_mapped, ) where - F: FnMut(ClientSinkItem) -> SerializedSinkItem + Send + 'static, - G: FnMut(SerializedSinkItem) -> ServerSinkItem + Send + 'static, + F: FnMut(ClientSinkItem) -> SerializedSinkItem, + G: FnMut(SerializedSinkItem) -> ServerSinkItem, { let (client, server) = unbounded(); @@ -61,6 +62,29 @@ where (client, server) } +/// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's +/// [`Sink`]. +pub fn unbounded_for_client_server_context() -> ( + impl Transport, Resp>, + impl Transport>, +) { + unbounded_mapped(map_client_context_to_shared, map_shared_context_to_server) +} + +/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. +pub fn map_client_context_to_shared( + msg: ClientMessage, +) -> ClientMessage { + msg.map_context(|ctx| ctx.shared_context) +} + +/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +pub fn map_shared_context_to_server( + msg: ClientMessage, +) -> ClientMessage { + msg.map_context(ServerContext::new) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -183,9 +207,8 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::context::{ClientContext, ServerContext, SharedContext}; use crate::{ - ClientMessage, ServerError, + ServerError, client::{self, RpcError}, context, server::{BaseChannel, incoming::Incoming, serve}, @@ -210,10 +233,8 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (client_channel, server_channel) = + transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(future::ready(server_channel)) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 0ee12183d..1ac04af13 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,6 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -44,11 +45,7 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok(|msg: ClientMessage| { - msg.map_context(|ctx| ServerContext::new(ctx)) - }) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) @@ -57,9 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default) .await? - .with(|msg: ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }); + .with(|msg| future::ok(map_client_context_to_shared(msg))); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index b6ba72026..ebebef660 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -5,6 +5,7 @@ use futures::{ }; use std::time::{Duration, Instant}; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; use tarpc::{ ClientMessage, client::{self}, @@ -36,10 +37,7 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); @@ -79,10 +77,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -121,13 +116,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok( - |msg: tarpc::ClientMessage| { - msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) - }, - ) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -135,11 +124,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( @@ -170,13 +155,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| { - t.map_ok( - |msg: tarpc::ClientMessage| { - msg.map_context(|ctx| tarpc::context::ServerContext::new(ctx)) - }, - ) - }) + .map(|t| t.map_ok(map_shared_context_to_server)) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -184,11 +163,8 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with( - |msg: tarpc::ClientMessage| { - future::ok(msg.map_context(|ctx| ctx.shared_context)) - }, - ); + let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail @@ -209,10 +185,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(ready(rx)) @@ -242,10 +215,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( stream::once(ready(rx)) @@ -282,10 +252,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = transport::channel::unbounded_for_client_server_context(); tokio::spawn( BaseChannel::with_defaults(rx) @@ -324,10 +291,7 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded_mapped( - |msg: ClientMessage| msg.map_context(|ctx| ctx.shared_context), - |msg: ClientMessage| msg.map_context(ServerContext::new), - ); + let (tx, rx) = channel::unbounded_for_client_server_context(); tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); From b2eb13b72a7c80d43acc1fee40fe6eadbdec9185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 13:02:33 +0100 Subject: [PATCH 09/26] allow transport to access server context on response as well --- example-service/src/client.rs | 6 +- example-service/src/server.rs | 9 +- plugins/src/lib.rs | 2 +- tarpc/examples/compression.rs | 9 +- tarpc/examples/custom_transport.rs | 8 +- tarpc/examples/pubsub.rs | 12 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 26 ++--- tarpc/src/client.rs | 27 +++-- tarpc/src/lib.rs | 18 ++- tarpc/src/server.rs | 77 +++++++----- .../src/server/limits/requests_per_channel.rs | 27 +++-- tarpc/src/server/testing.rs | 18 ++- tarpc/src/transport/channel.rs | 110 ++++++++++++++++-- tarpc/tests/dataservice.rs | 9 +- tarpc/tests/service_functional.rs | 11 +- 16 files changed, 256 insertions(+), 119 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e2c327dfb..71c9704ea 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,7 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; -use tarpc::transport::channel::map_client_context_to_shared; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -32,9 +32,7 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport.await?); // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 7871c5f86..fe61904b9 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -11,18 +11,19 @@ use rand::{ thread_rng, }; use service::{World, init_tracing}; +use std::ops::Deref; use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; use tarpc::context::{ServerContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_server}; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, }; use tokio::time; -use tarpc::transport::channel::map_shared_context_to_server; #[derive(Parser)] struct Flags { @@ -64,14 +65,14 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().get_ref().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().get_ref().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 21e3cd35b..71d7d3c80 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -739,7 +739,7 @@ impl ServiceGenerator<'_> { ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<::tarpc::context::ClientContext, #response_ident>> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 4e09d625e..0801ce9f4 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,10 +9,9 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ - ClientMessage, client, context, + client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, @@ -127,7 +126,7 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) @@ -136,7 +135,7 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 350828743..7c23a1fa7 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,11 +6,11 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; -use tarpc::{ClientMessage, serde_transport as transport}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -40,7 +40,7 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6a6ce5723..8094c490d 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -49,7 +49,7 @@ use std::{ }; use subscriber::Subscriber as _; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -106,7 +106,7 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = publisher.map_ok(map_shared_context_to_server); + let publisher = map_transport_to_server(publisher); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -167,7 +167,7 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = publisher.map_ok(map_shared_context_to_server); + let publisher = map_transport_to_server(publisher); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -188,7 +188,7 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = conn.with(|msg| future::ok(map_client_context_to_shared(msg))); + let conn = map_transport_to_client(conn); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -354,9 +354,7 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - tcp::connect(addrs.publisher, Json::default) - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))), + map_transport_to_client(tcp::connect(addrs.publisher, Json::default).await?), ) .spawn(); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 56583b05c..2d90650a5 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -15,7 +15,7 @@ use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -115,7 +115,7 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = transport.map_ok(map_shared_context_to_server); + let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) @@ -145,7 +145,7 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let answer = PingServiceClient::new(Default::default(), transport) .spawn() .ping(&mut ClientContext::current()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index acf631be8..0930aae1d 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,7 +20,7 @@ use std::{ }, }; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -126,8 +126,10 @@ where } fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; - N], + backends: [impl Transport>, Response> + + Send + + Sync + + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp>>, @@ -176,34 +178,26 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); - let map_context = - |msg: ClientMessage| future::ok(map_client_context_to_shared(msg)); - let add_client = add::AddClient::from(make_stub([ - tarpc::serde_transport::tcp::connect(addr1, Json::default) - .await? - .with(map_context), - tarpc::serde_transport::tcp::connect(addr2, Json::default) - .await? - .with(map_context), + map_transport_to_client(tarpc::serde_transport::tcp::connect(addr1, Json::default).await?), + map_transport_to_client(tarpc::serde_transport::tcp::connect(addr2, Json::default).await?), ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? .filter_map(|r| future::ready(r.ok())) - .map(|t| t.map_ok(map_shared_context_to_server)); + .map(map_transport_to_server); let addr = double_listener.get_ref().get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = - to_double_server.with(|msg| future::ok(map_client_context_to_shared(msg))); + let to_double_server = map_transport_to_client(to_double_server); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ebcb69db1..125f3ad4a 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -244,7 +244,7 @@ pub fn new( transport: C, ) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -292,7 +292,7 @@ pub struct RequestDispatch { impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -577,7 +577,7 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, response.message.map_err(RpcError::Server), @@ -657,7 +657,7 @@ where impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, { type Output = Result<(), ChannelError>; @@ -746,6 +746,7 @@ mod tests { server_channel .send(Response { request_id: 0, + context: ClientContext::current(), message: Ok("Resp".into()), }) .await @@ -775,6 +776,7 @@ mod tests { let (tx, mut response) = oneshot::channel(); tx.send(Ok(Response { request_id: 0, + context: ClientContext::current(), message: Ok("well done"), })) .unwrap(); @@ -825,6 +827,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: ClientContext::current(), message: Ok("hello".into()), }, ) @@ -1063,7 +1066,7 @@ mod tests { } impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1079,12 +1082,15 @@ mod tests { RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + UnboundedChannel< + Response, + ClientMessage, + >, >, >, >, Channel, - UnboundedChannel, Response>, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1162,8 +1168,11 @@ mod tests { } async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, + channel: &mut UnboundedChannel< + ClientMessage, + Response, + >, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 6f7e08fc0..565fe9f89 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -392,13 +392,29 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Response { +pub struct Response { /// The ID of the request being responded to. pub request_id: u64, + /// Trace context, deadline, and other cross-cutting concerns. + pub context: Ctx, /// The response body, or an error if the request failed. pub message: Result, } +impl Response { + /// Creates a modified Response by mapping the context using the provided function. + pub fn map_context(self, f: F) -> Response + where + F: FnOnce(Ctx) -> Ctx2, + { + Response { + request_id: self.request_id, + context: f(self.context), + message: self.message, + } + } +} + /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 87ac89681..649f21022 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -61,7 +61,7 @@ impl Config { /// Returns a channel backed by `transport` and configured with `self`. pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { BaseChannel::new(self, transport) } @@ -165,7 +165,7 @@ pub struct BaseChannel { impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -304,7 +304,10 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest<::Req>>, + Self: Transport< + Response::Resp>, + TrackedRequest<::Req>, + >, { /// Type of request item. type Req; @@ -378,7 +381,7 @@ where /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` - fn requests(self) -> Requests + fn requests(self) -> Requests where Self: Sized, { @@ -433,7 +436,7 @@ where impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Item = Result, ChannelError>; @@ -538,9 +541,9 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, { type Error = ChannelError; @@ -552,7 +555,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -593,7 +599,7 @@ impl AsRef for BaseChannel { impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, { type Req = Req; type Resp = Resp; @@ -615,19 +621,19 @@ where /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so /// it must be continually polled to ensure progress. #[pin_project] -pub struct Requests +pub struct Requests where C: Channel, { #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } -impl Requests +impl Requests where C: Channel, { @@ -644,7 +650,7 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } @@ -716,7 +722,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -789,7 +795,7 @@ where } } -impl fmt::Debug for Requests +impl fmt::Debug for Requests where C: Channel, { @@ -825,7 +831,7 @@ pub struct InFlightRequest { abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } impl InFlightRequest { @@ -904,6 +910,7 @@ impl InFlightRequest { tracing::debug!("CompleteRequest"); let response = Response { request_id, + context, message, }; let _ = response_tx.send(response).await; @@ -927,7 +934,7 @@ fn print_err(e: &(dyn Error + 'static)) -> String { .join(": ") } -impl Stream for Requests +impl Stream for Requests where C: Channel, { @@ -1002,11 +1009,14 @@ mod tests { BaseChannel< Req, Resp, - UnboundedChannel, Response>, + UnboundedChannel< + ClientMessage, + Response, + >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1016,15 +1026,19 @@ mod tests { Pin< Box< Requests< + ServerContext, BaseChannel< Req, Resp, - UnboundedChannel, Response>, + UnboundedChannel< + ClientMessage, + Response, + >, >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1039,15 +1053,19 @@ mod tests { Pin< Box< Requests< + ServerContext, BaseChannel< Req, Resp, - channel::Channel, Response>, + channel::Channel< + ClientMessage, + Response, + >, >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1322,7 +1340,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1331,6 +1349,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1398,6 +1417,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1409,6 +1429,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: ServerContext::current(), message: Ok(()), }) .await @@ -1419,7 +1440,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1449,6 +1470,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(()), }) .unwrap(); @@ -1459,7 +1481,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: ServerContext::current(), message: (), }) .unwrap(); @@ -1469,6 +1491,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: ServerContext::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index bd9c103b0..395ded512 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use crate::context::ServerContext; use crate::{ Response, ServerError, server::{Channel, Config}, @@ -67,6 +68,7 @@ where self.as_mut().start_send(Response { request_id: r.request.id, + context: r.request.context, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -80,7 +82,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -92,7 +94,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response<::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -268,7 +270,8 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> { + -> PendingSink>, Response> + { PendingSink { ghost: PhantomData } } } @@ -293,7 +296,9 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel + for PendingSink>, Response> + { type Req = Req; type Resp = Resp; type Transport = (); @@ -326,16 +331,16 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: ServerContext::current(), message: Ok(1), }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.front(), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); + + let result = throttler.inner.sink.front(); + + assert_eq!(result.map(|r| r.request_id), Some(0)); + + assert_eq!(result.map(|r| &r.message), Some(&Ok(1))); } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 709167751..63b65d697 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use crate::context::ServerContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -38,14 +39,19 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> + for FakeChannel> +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -65,7 +71,8 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel + for FakeChannel>, Response> where Req: Unpin, { @@ -86,7 +93,7 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -111,7 +118,8 @@ impl FakeChannel>, Response> { } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 65e987d02..7615d8fe1 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -7,7 +7,9 @@ //! Transports backed by in-memory channels. use crate::context::{ClientContext, ServerContext, SharedContext}; -use crate::{ClientMessage, Transport}; +use crate::{ClientMessage, Response, Transport}; +use futures::future::{Ready}; +use futures::sink::With; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; @@ -43,21 +45,40 @@ pub fn unbounded() -> ( /// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's /// [`Sink`]. -pub fn unbounded_mapped( +pub fn unbounded_mapped< + SerializedSinkItem, + SerializedItem, + ClientSinkItem, + ServerSinkItem, + ClientItem, + ServerItem, + F, + G, + H, + I, +>( mut f: F, mut g: G, + mut h: H, + mut i: I, ) -> ( - impl Transport, - impl Transport, + impl Transport, + impl Transport, ) where F: FnMut(ClientSinkItem) -> SerializedSinkItem, G: FnMut(SerializedSinkItem) -> ServerSinkItem, + H: FnMut(SerializedItem) -> ClientItem, + I: FnMut(ServerItem) -> SerializedItem, { let (client, server) = unbounded(); - let client = client.with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))); - let server = server.map_ok(move |msg: SerializedSinkItem| g(msg)); + let client = client + .with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))) + .map_ok(move |msg: SerializedItem| h(msg)); + let server = server + .map_ok(move |msg: SerializedSinkItem| g(msg)) + .with(move |msg: ServerItem| future::ready(Ok(i(msg)))); (client, server) } @@ -65,26 +86,93 @@ where /// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's /// [`Sink`]. pub fn unbounded_for_client_server_context() -> ( - impl Transport, Resp>, - impl Transport>, + impl Transport, Response>, + impl Transport, ClientMessage>, ) { - unbounded_mapped(map_client_context_to_shared, map_shared_context_to_server) + unbounded_mapped( + map_req_client_context_to_shared, + map_req_shared_context_to_server, + map_resp_shared_context_to_client, + map_resp_server_context_to_shared, + ) } /// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -pub fn map_client_context_to_shared( +fn map_req_client_context_to_shared( msg: ClientMessage, ) -> ClientMessage { msg.map_context(|ctx| ctx.shared_context) } /// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. -pub fn map_shared_context_to_server( +fn map_req_shared_context_to_server( msg: ClientMessage, ) -> ClientMessage { msg.map_context(ServerContext::new) } +/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. +fn map_resp_server_context_to_shared( + resp: Response, +) -> Response { + resp.map_context(|ctx| ctx.shared_context) +} + +/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +fn map_resp_shared_context_to_client( + msg: Response, +) -> Response { + msg.map_context(ClientContext::new) +} + +/// TODO: document +/// Yuck, but impl trait will loose our ability to do t.as_ref() +pub fn map_transport_to_client( + t: T, +) -> futures::stream::MapOk< + With< + T, + ClientMessage, + ClientMessage, + Ready, E>>, + fn(ClientMessage) -> Ready, E>>, + >, + fn(Response) -> Response, +> +where + T: Transport, Response>, + E: From +{ + let f: fn(ClientMessage) -> Ready, E>> = |resp| futures::future::ok(map_req_client_context_to_shared(resp)); + + t.with(f).map_ok(map_resp_shared_context_to_client) +} + +/// TODO: document +/// +/// Yuck, but impl trait will loose our ability to do t.as_ref() +pub fn map_transport_to_server( + t: T, +) -> futures::stream::MapOk< + With< + T, + Response, + Response, + Ready, E>>, + fn(Response) -> Ready, E>>, + >, + fn(ClientMessage) -> ClientMessage, +> +where + T: Transport, ClientMessage>, + E: From +{ + let f: fn(Response) -> Ready, E>> = |resp| futures::future::ok(map_resp_server_context_to_shared(resp)); + + t.with(f) + .map_ok(map_req_shared_context_to_server) +} + /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 1ac04af13..54fadf77d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -45,16 +45,15 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) .for_each(spawn), ); - let transport = serde_transport::tcp::connect(addr, Json::default) - .await? - .with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = serde_transport::tcp::connect(addr, Json::default).await?; + let transport = map_transport_to_client(transport); let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index ebebef660..b65a66104 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,8 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_client_context_to_shared, map_shared_context_to_server}; +use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; use tarpc::{ ClientMessage, client::{self}, @@ -116,7 +115,7 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -124,7 +123,7 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( @@ -155,7 +154,7 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(|t| t.map_ok(map_shared_context_to_server)) + .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -163,7 +162,7 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = transport.with(|msg| future::ok(map_client_context_to_shared(msg))); + let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); From 54b8fe8a0fd6b32e305bed89d4dfa529e384ea6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 16:05:33 +0100 Subject: [PATCH 10/26] allow server to mutate shared context --- tarpc/src/client.rs | 90 +++++++++++++------------- tarpc/src/client/in_flight_requests.rs | 19 +++--- 2 files changed, 53 insertions(+), 56 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 125f3ad4a..ee85d842d 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -166,14 +166,19 @@ where }) .await .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; - response_guard.response().await + + let (response_ctx, r) = response_guard.response().await?; + + ctx.shared_context = response_ctx.shared_context; + + Ok(r) } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -201,7 +206,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result { + async fn response(mut self) -> Result<(ClientContext, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -280,7 +285,7 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -296,7 +301,7 @@ where { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -522,12 +527,10 @@ where let trace_context = ctx.trace_context; let deadline = ctx.deadline; - let client_context = context::ClientContext::new(ctx); - let request = ClientMessage::Request(Request { id: request_id, message: request, - context: client_context, + context: ClientContext::new(ctx), }); self.in_flight_requests() @@ -580,7 +583,7 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server), + response.message.map_err(RpcError::Server).map(|m| (response.context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -688,11 +691,11 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, + pub ctx: context::SharedContextg, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -752,7 +755,7 @@ mod tests { .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp"); } #[tokio::test] @@ -774,12 +777,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok(Response { - request_id: 0, - context: ClientContext::current(), - message: Ok("well done"), - })) - .unwrap(); + tx.send(Ok((ClientContext::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -1116,37 +1114,11 @@ mod tests { (Box::pin(dispatch), channel, server_channel) } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { - let permit = channel.to_dispatch.reserve().await.unwrap(); - |request| { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - let request = DispatchRequest { - ctx: SharedContext::current(), - span: Span::current(), - request_id, - request: request.to_string(), - response_completion, - }; - permit.send(request); - ResponseGuard { - response, - cancellation: &channel.cancellation, - request_id, - cancel: true, - } - } - } - async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -1167,6 +1139,32 @@ mod tests { response_guard } + async fn reserve_for_send<'a>( + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { + let permit = channel.to_dispatch.reserve().await.unwrap(); + |request| { + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request = DispatchRequest { + ctx: SharedContext::current(), + span: Span::current(), + request_id, + request: request.to_string(), + response_completion, + }; + permit.send(request); + ResponseGuard { + response, + cancellation: &channel.cancellation, + request_id, + cancel: true, + } + } + } + async fn send_response( channel: &mut UnboundedChannel< ClientMessage, diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 7a554de27..90f60c527 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,7 +1,4 @@ -use crate::{ - trace, - util::{Compact, TimeUntil}, -}; +use crate::{trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -11,6 +8,8 @@ use std::{ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; +use crate::context::ClientContext; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -32,7 +31,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -60,7 +59,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -78,8 +77,8 @@ impl InFlightRequests { } } - /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option { + /// Removes a request without aborting. Returns true if the request was found. + pub fn complete_request(&mut self, request_id: u64, result: Result<(ClientContext, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -97,7 +96,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Res + 'a, + mut result: impl FnMut() -> Result<(ClientContext, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -123,7 +122,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Res, + expired_error: impl Fn() -> Result<(ClientContext, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); From 6ded9edff5f4dd391a8cc5a3702ec08d391f1f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Tue, 25 Nov 2025 16:10:30 +0100 Subject: [PATCH 11/26] fix typo --- tarpc/src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ee85d842d..a42c94491 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -691,7 +691,7 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContextg, ///TODO: <-- this should be a &mut ClientContext + pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, From 6fa12926de7dc54eb38799817c744b187e556b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 13:09:55 +0100 Subject: [PATCH 12/26] make servertransport generic, defined by the service implementation. --- example-service/src/server.rs | 1 + plugins/src/lib.rs | 17 +- plugins/tests/service.rs | 4 + tarpc/Cargo.toml | 2 + tarpc/examples/compression.rs | 2 + tarpc/examples/custom_transport.rs | 2 + tarpc/examples/pubsub.rs | 2 + tarpc/examples/readme.rs | 1 + tarpc/examples/tls_over_tcp.rs | 1 + tarpc/examples/tracing.rs | 2 + tarpc/src/client.rs | 37 +-- tarpc/src/client/in_flight_requests.rs | 12 +- tarpc/src/client/stub.rs | 14 +- tarpc/src/client/stub/load_balance.rs | 10 +- tarpc/src/client/stub/mock.rs | 13 +- tarpc/src/client/stub/retry.rs | 4 +- tarpc/src/context.rs | 43 +++- tarpc/src/lib.rs | 8 +- tarpc/src/server.rs | 219 ++++++++++-------- tarpc/src/server/incoming.rs | 2 +- tarpc/src/server/limits/channels_per_key.rs | 1 + .../src/server/limits/requests_per_channel.rs | 14 +- tarpc/src/server/request_hook.rs | 18 +- tarpc/src/server/request_hook/after.rs | 17 +- tarpc/src/server/request_hook/before.rs | 82 ++++--- .../server/request_hook/before_and_after.rs | 19 +- tarpc/src/server/testing.rs | 24 +- tarpc/tests/dataservice.rs | 1 + tarpc/tests/service_functional.rs | 12 +- 29 files changed, 347 insertions(+), 237 deletions(-) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index fe61904b9..5e176dfa2 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -38,6 +38,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 71d7d3c80..432b2f1c8 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -402,7 +402,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: &mut ServerContext, a: i32, b: i32) -> i32 { +/// type Context = ServerContext; +/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -559,7 +560,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: &mut ::tarpc::context::ServerContext, #( #args ),*) -> #output; + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; } }, ); @@ -568,6 +569,8 @@ impl ServiceGenerator<'_> { quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { + type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>; + #( #rpc_fns )* /// Returns a serving function to use with @@ -578,11 +581,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + where S: ::tarpc::client::stub::Stub { } } @@ -621,9 +624,9 @@ impl ServiceGenerator<'_> { { type Req = #request_ident; type Resp = #response_ident; + type ServerCtx = S::Context; - - async fn serve(self, ctx: &mut ::tarpc::context::ServerContext, req: #request_ident) + async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -787,7 +790,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::ClientContext, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ServerCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 756766621..ef49b9666 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; +use tarpc::context::ServerContext; #[test] fn att_service_trait() { @@ -12,6 +13,7 @@ fn att_service_trait() { } impl Foo for () { + type Context = ServerContext; async fn two_part( self, _: &mut context::ServerContext, @@ -42,6 +44,7 @@ fn raw_idents() { } impl r#trait for () { + type Context = ServerContext; async fn r#await( self, _: &mut context::ServerContext, @@ -69,6 +72,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { + type Context = ServerContext; async fn foo(self, _: &mut context::ServerContext) {} } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 778eb0938..0a5efc137 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,6 +61,8 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" +anymap3 = "1.0.1" +serde-value = "0.7" [dev-dependencies] assert_matches = "1.4" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 0801ce9f4..c00ffc9f3 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -16,6 +16,7 @@ use tarpc::{ server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; +use tarpc::context::ServerContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -109,6 +110,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}!") } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 7c23a1fa7..415cb5442 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use console_subscriber::Server; use futures::prelude::*; use tarpc::context::{ClientContext, ServerContext, SharedContext}; use tarpc::serde_transport as transport; @@ -22,6 +23,7 @@ pub trait PingService { struct Service; impl PingService for Service { + type Context = ServerContext; async fn ping(self, _: &mut ServerContext) {} } #[tokio::main] diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 8094c490d..6755e49ca 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -82,6 +82,7 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { + type Context = ServerContext; async fn topics(self, _: &mut context::ServerContext) -> Vec { self.topics.clone() } @@ -271,6 +272,7 @@ impl Publisher { } impl publisher::Publisher for Publisher { + type Context = ServerContext; async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index b20d4ab91..c7e8de00b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -25,6 +25,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { + type Context = ServerContext; async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hello, {name}!") } diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 2d90650a5..4ed3298bb 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -33,6 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { + type Context = ServerContext; async fn ping(self, _: &mut ServerContext) -> String { "🔒".to_owned() } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 0930aae1d..f747c9d75 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -58,6 +58,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { + type Context = ServerContext; async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } @@ -72,6 +73,7 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { + type Context = ServerContext; async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { self.add_client .add(&mut context::ClientContext::current(), x, x) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index a42c94491..b6763a9b2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::context::ClientContext; +use crate::context::{ClientContext, ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -125,23 +125,24 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline.time_until()), otel.kind = "client", otel.name = %request.name()) )] - pub async fn call( + pub async fn call>( &self, - ctx: &mut context::ClientContext, + ctx: &mut Ctx, request: Req, ) -> Result { let span = Span::current(); - ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + let mut shared_context = ctx.extract(); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - ctx.trace_context.new_child() + shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(ctx.trace_id())); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id())); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -158,7 +159,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx: ctx.shared_context.clone(), + ctx: shared_context, span, request_id, request, @@ -169,7 +170,7 @@ where let (response_ctx, r) = response_guard.response().await?; - ctx.shared_context = response_ctx.shared_context; + ctx.update(response_ctx); Ok(r) } @@ -178,7 +179,7 @@ where /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -206,7 +207,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result<(ClientContext, Resp), RpcError> { + async fn response(mut self) -> Result<(SharedContext, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -583,7 +584,7 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context, m)), + response.message.map_err(RpcError::Server).map(|m| (response.context.shared_context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -695,7 +696,7 @@ struct DispatchRequest { pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -777,7 +778,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((ClientContext::current(), "well done"))).unwrap(); + tx.send(Ok((SharedContext::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -1117,8 +1118,8 @@ mod tests { async fn send_request<'a>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -1141,8 +1142,8 @@ mod tests { async fn reserve_for_send<'a>( channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 90f60c527..5b648098b 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -9,7 +9,7 @@ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; use crate::client::RpcError; -use crate::context::ClientContext; +use crate::context::{SharedContext}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -31,7 +31,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -59,7 +59,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -78,7 +78,7 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Result<(ClientContext, Res), RpcError>) -> Option { + pub fn complete_request(&mut self, request_id: u64, result: Result<(SharedContext, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -96,7 +96,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Result<(ClientContext, Res), RpcError> + 'a, + mut result: impl FnMut() -> Result<(SharedContext, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -122,7 +122,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Result<(ClientContext, Res), RpcError>, + expired_error: impl Fn() -> Result<(SharedContext, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index b99f8e42c..6fa159dd7 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -6,6 +6,7 @@ use crate::{ context, server::Serve, }; +use crate::context::{ClientContext, ServerContext}; pub mod load_balance; pub mod retry; @@ -23,10 +24,13 @@ pub trait Stub { /// The service response type. type Resp; + ///TODO: document + type ServerCtx; + /// Calls a remote service. async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result; } @@ -37,10 +41,11 @@ where { type Req = Req; type Resp = Resp; + type ServerCtx = ClientContext; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Req, ) -> Result { Self::call(self, ctx, request).await @@ -49,13 +54,14 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; + type ServerCtx = ClientContext; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut ClientContext, req: Self::Req, ) -> Result { let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 62c8bf677..5b319c6c8 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -7,7 +7,6 @@ pub use round_robin::RoundRobin; mod round_robin { use crate::{ client::{RpcError, stub}, - context, }; use cycle::AtomicCycle; @@ -17,10 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -99,8 +99,7 @@ mod round_robin { /// the same stub. mod consistent_hash { use crate::{ - client::{RpcError, stub}, - context, + client::{RpcError, stub} }; use std::{ collections::hash_map::RandomState, @@ -116,10 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index bebd8fc99..9a22d101e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,16 +1,17 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, - context, }; use std::{collections::HashMap, hash::Hash, io}; +use std::marker::PhantomData; /// A mock stub that returns user-specified responses. -pub struct Mock { +pub struct Mock { responses: HashMap, + ghost: PhantomData } -impl Mock +impl Mock where Req: Eq + Hash, { @@ -18,21 +19,23 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), + ghost: PhantomData } } } -impl Stub for Mock +impl Stub for Mock where Req: Eq + Hash + RequestName, Resp: Clone, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; async fn call( &self, - _: &mut context::ClientContext, + _: &mut Self::ServerCtx, request: Self::Req, ) -> Result { self.responses diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index d93daa156..2cf950aed 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -3,7 +3,6 @@ use crate::{ RequestName, client::{RpcError, stub}, - context, }; use std::sync::Arc; @@ -15,10 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; + type ServerCtx = Stub::ServerCtx; async fn call( &self, - ctx: &mut context::ClientContext, + ctx: &mut Self::ServerCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index bbbc3721d..798044c93 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -23,7 +23,6 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. #[derive(Debug, Clone)] -#[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request @@ -36,7 +35,25 @@ pub struct SharedContext { /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. - pub trace_context: trace::Context, + pub trace_context: trace::Context +} + +///TODO +pub trait ExtractContext { + ///TODO + fn extract(&self) -> Ctx; + ///TODO + fn update(&mut self, value: Ctx); +} + +impl ExtractContext for T where T: Clone { + fn extract(&self) -> T { + self.clone() + } + + fn update(&mut self, value: T) { + *self = value + } } /// Request context that carries request-scoped server side information like deadlines and trace information @@ -100,6 +117,28 @@ impl ClientContext { } } +impl ExtractContext for ClientContext { + fn extract(&self) -> SharedContext { + self.shared_context.clone() + } + + fn update(&mut self, value: SharedContext) { + self.shared_context = value + } +} + +impl ExtractContext for ServerContext { + fn extract(&self) -> SharedContext { + self.shared_context.clone() + } + + fn update(&mut self, value: SharedContext) { + self.shared_context = value + } +} + + + impl Deref for ClientContext { type Target = SharedContext; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 565fe9f89..0578a392f 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,8 +124,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { +//! type Context = context::ServerContext; //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: &mut context::ServerContext, name: String) -> String { +//! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -160,8 +161,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: &mut context::ServerContext, name: String) -> String { +//! # type Context = ServerContext; +//! # // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 649f21022..fe7440f7e 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,10 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::context::ServerContext; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{self, SpanExt}, + context::{SpanExt}, trace, util::TimeUntil, }; @@ -28,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::{ExtractContext, SharedContext}; mod in_flight_requests; pub mod request_hook; @@ -59,9 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -70,6 +71,9 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] pub trait Serve { + ///TODO document + type ServerCtx; + /// Type of request. type Req: RequestName; @@ -79,19 +83,19 @@ pub trait Serve { /// Responds to a single request. async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut Self::ServerCtx, req: Self::Req, ) -> Result; } /// A Serve wrapper around a Fn. #[derive(Debug)] -pub struct ServeFn { +pub struct ServeFn { f: F, - data: PhantomData Resp>, + data: PhantomData<(Req, Resp, ServerCtx)>, } -impl Clone for ServeFn +impl Clone for ServeFn where F: Clone, { @@ -103,16 +107,13 @@ where } } -impl Copy for ServeFn where F: Copy {} +impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut context::ServerContext, - Req, - ) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -120,18 +121,19 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, for<'a> F: FnOnce( - &'a mut context::ServerContext, + &'a mut ServerCtx, Req, ) -> Pin> + 'a + Send>>, { + type ServerCtx = ServerCtx; type Req = Req; type Resp = Resp; - async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -147,7 +149,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -160,12 +162,13 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(Req, Resp, ServeCtx)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -211,28 +214,29 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { + request: Request, + ) -> Result, AlreadyExistsError> { + let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline.time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + span.set_context(&shared_context); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - request.context.trace_context.new_child() + shared_context.trace_context.new_child() }); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - request.context.deadline, + shared_context.deadline, span.clone(), ); match start { @@ -257,7 +261,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -265,9 +269,9 @@ impl fmt::Debug for BaseChannel { /// A request tracked by a [`Channel`]. #[derive(Debug)] -pub struct TrackedRequest { +pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -305,8 +309,8 @@ pub struct TrackedRequest { pub trait Channel where Self: Transport< - Response::Resp>, - TrackedRequest<::Req>, + Response::Resp>, + TrackedRequest::Req>, >, { /// Type of request item. @@ -317,6 +321,8 @@ where /// The wrapped transport. type Transport; + ///TODO document + type ServerCtx; /// Configuration of the channel. fn config(&self) -> &Config; @@ -381,7 +387,7 @@ where /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` - fn requests(self) -> Requests + fn requests(self) -> Requests where Self: Sized, { @@ -428,17 +434,18 @@ where where Self: Sized, Self::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[derive(Clone, Copy, Debug)] @@ -541,10 +548,11 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, + ServerCtx: ExtractContext { type Error = ChannelError; @@ -557,7 +565,7 @@ where fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() @@ -591,19 +599,22 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { + type Req = Req; type Resp = Resp; type Transport = T; + type ServerCtx = ServerCtx; fn config(&self) -> &Config { &self.config @@ -621,19 +632,19 @@ where /// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so /// it must be continually polled to ensure progress. #[pin_project] -pub struct Requests +pub struct Requests where C: Channel, { #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } -impl Requests +impl Requests where C: Channel, { @@ -650,14 +661,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -722,7 +733,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -779,7 +790,7 @@ where pub fn execute(self, serve: S) -> impl Stream> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.take_while(|result| { if let Err(e) = result { @@ -795,7 +806,7 @@ where } } -impl fmt::Debug for Requests +impl fmt::Debug for Requests where C: Channel, { @@ -826,17 +837,17 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { - request: Request, +pub struct InFlightRequest { + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -889,7 +900,7 @@ impl InFlightRequest { pub async fn execute(self, serve: S) where Req: RequestName, - S: Serve, + S: Serve, { let Self { response_tx, @@ -934,11 +945,11 @@ fn print_err(e: &(dyn Error + 'static)) -> String { .join(": ") } -impl Stream for Requests +impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -984,7 +995,6 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::ServerContext; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1002,6 +1012,7 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use crate::context::{ExtractContext, SharedContext}; fn test_channel() -> ( Pin< @@ -1010,13 +1021,14 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1026,19 +1038,20 @@ mod tests { Pin< Box< Requests< - ServerContext, BaseChannel< Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, + >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1053,19 +1066,19 @@ mod tests { Pin< Box< Requests< - ServerContext, BaseChannel< Req, Resp, channel::Channel< - ClientMessage, - Response, + ClientMessage, + Response, >, + SharedContext >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1075,9 +1088,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::ServerContext::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1101,13 +1114,15 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { + impl BeforeRequest for SetDeadline where ServerCtx: ExtractContext { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { - ctx.deadline = self.0; + let mut inner = ctx.extract(); + inner.deadline = self.0; + ctx.update(inner); Ok(()) } } @@ -1115,7 +1130,7 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::ServerContext, i| { + let serve = serve(move |ctx: &mut context::SharedContext, i| { async move { assert_eq!(ctx.deadline, some_time); Ok(i) @@ -1123,7 +1138,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::ServerContext::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1143,20 +1158,20 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { + impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::ServerContext, + _: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } - impl AfterRequest for PrintLatency { + impl AfterRequest for PrintLatency { async fn after( &mut self, - _: &mut context::ServerContext, + _: &mut ServerCtx, _: &mut Result, ) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); @@ -1192,14 +1207,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1215,7 +1230,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1223,7 +1238,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1246,7 +1261,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1275,7 +1290,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1317,7 +1332,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1340,7 +1355,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1349,7 +1364,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1408,7 +1423,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1417,7 +1432,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1429,7 +1444,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .await @@ -1440,7 +1455,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1461,7 +1476,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::ServerContext::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1470,7 +1485,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1481,7 +1496,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1491,7 +1506,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 1868cbe47..56f393b84 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -33,7 +33,7 @@ where ) -> impl Stream>> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.map(move |channel| channel.execute(serve.clone())) } diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 64b644278..3ffdfac89 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -107,6 +107,7 @@ where type Req = C::Req; type Resp = C::Resp; type Transport = C::Transport; + type ServerCtx = C::ServerCtx; fn config(&self) -> &server::Config { self.inner.config() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 395ded512..383abb9c8 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::ServerContext; use crate::{ Response, ServerError, server::{Channel, Config}, @@ -82,7 +81,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -94,7 +93,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -121,6 +120,7 @@ where type Req = ::Req; type Resp = ::Resp; type Transport = ::Transport; + type ServerCtx = ::ServerCtx; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -190,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context::{ServerContext, SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -270,7 +271,7 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> + -> PendingSink>, Response> { PendingSink { ghost: PhantomData } } @@ -297,11 +298,12 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = ServerContext; fn config(&self) -> &Config { unimplemented!() } @@ -331,7 +333,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: ServerContext::current(), + context: SharedContext::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 38b0998bf..338059f7d 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -62,9 +62,9 @@ pub trait RequestHook: Serve { /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> HookThenServe + fn before(self, hook: Hook) -> HookThenServe where - Hook: BeforeRequest, + Hook: BeforeRequest, Self: Sized, { HookThenServe::new(self, hook) @@ -107,7 +107,7 @@ pub trait RequestHook: Serve { /// ``` fn after(self, hook: Hook) -> ServeThenHook where - Hook: AfterRequest, + Hook: AfterRequest, Self: Sized, { ServeThenHook::new(self, hook) @@ -133,17 +133,17 @@ pub trait RequestHook: Serve { /// /// struct PrintLatency(Instant); /// - /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { + /// impl BeforeRequest for PrintLatency { + /// async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } /// } /// - /// impl AfterRequest for PrintLatency { + /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::ServerContext, + /// _: &mut ServerCtx, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -160,9 +160,9 @@ pub trait RequestHook: Serve { fn before_and_after( self, hook: Hook, - ) -> HookThenServeThenHook + ) -> HookThenServeThenHook where - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest, Self: Sized, { HookThenServeThenHook::new(self, hook) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index 64d65807f..ce6319e25 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -6,30 +6,30 @@ //! Provides a hook that runs after request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. async fn after( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, resp: &mut Result, ); } -impl AfterRequest for F +impl AfterRequest for F where - F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, + F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { async fn after( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, resp: &mut Result, ) { self(ctx, resp).await @@ -60,14 +60,15 @@ impl Clone for ServeThenHook { impl Serve for ServeThenHook where Serv: Serve, - Hook: AfterRequest, + Hook: AfterRequest, { type Req = Serv::Req; type Resp = Serv::Resp; + type ServerCtx = Serv::ServerCtx; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut Serv::ServerCtx, req: Serv::Req, ) -> Result { let ServeThenHook { diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 1f647227f..3e2e091c8 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,12 +6,13 @@ //! Provides a hook that runs before request execution. -use crate::{ServerError, context, server::Serve}; +use std::marker::PhantomData; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -21,24 +22,24 @@ pub trait BeforeRequest { /// enforce a maximum deadline on all requests. async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest { /// The hook returned by `BeforeRequestList::then`. - type Then: BeforeRequest + type Then: BeforeRequest where - Next: BeforeRequest; + Next: BeforeRequest; /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. - fn then>(self, next: Next) -> Self::Then; + fn then>(self, next: Next) -> Self::Then; /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::ServerContext, &Req) -> Fut, + Next: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, >( self, @@ -51,21 +52,21 @@ pub trait BeforeRequestList: BeforeRequest { } /// The service fn returned by `BeforeRequestList::serving`. - type Serve>: Serve; + type Serve>: Serve; /// Runs the list of request hooks before execution of the given serve fn. /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. - fn serving>(self, serve: S) -> Self::Serve; + fn serving>(self, serve: S) -> Self::Serve; } -impl BeforeRequest for F +impl BeforeRequest for F where - F: FnMut(&mut context::ServerContext, &Req) -> Fut, + F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError> { self(ctx, req).await @@ -73,29 +74,36 @@ where } /// A Service function that runs a hook before request execution. -#[derive(Clone)] -pub struct HookThenServe { +pub struct HookThenServe { serve: Serv, hook: Hook, + ghost: PhantomData } -impl HookThenServe { +impl Clone for HookThenServe { + fn clone(&self) -> Self { + Self::new(self.serve.clone(), self.hook.clone()) + } +} + +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook } + Self { serve, hook, ghost: PhantomData } } } -impl Serve for HookThenServe +impl Serve for HookThenServe where - Serv: Serve, - Hook: BeforeRequest, + Serv: Serve, + Hook: BeforeRequest, { + type ServerCtx = ServerCtx; type Req = Serv::Req; type Resp = Serv::Resp; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: Self::Req, ) -> Result { let HookThenServe { @@ -129,7 +137,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::ServerContext::current(); +/// let mut context = context::SharedContext::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -146,12 +154,12 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest> BeforeRequest +impl, Rest: BeforeRequest, ServerCtx> BeforeRequest for BeforeRequestCons { async fn before( &mut self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: &Req, ) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; @@ -161,45 +169,45 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { Ok(()) } } -impl, Rest: BeforeRequestList> BeforeRequestList +impl, Rest: BeforeRequestList, ServerCtx> BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; BeforeRequestCons(first, rest.then(next)) } - type Serve> = HookThenServe; + type Serve> = HookThenServe; - fn serving>(self, serve: S) -> Self::Serve { + fn serving>(self, serve: S) -> Self::Serve { HookThenServe::new(serve, self) } } -impl BeforeRequestList for BeforeRequestNil { +impl BeforeRequestList for BeforeRequestNil { type Then = BeforeRequestCons where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) } - type Serve> = S; + type Serve> = S; - fn serving>(self, serve: S) -> S { + fn serving>(self, serve: S) -> S { serve } } @@ -223,7 +231,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = context::ServerContext::current(); + let mut context = crate::context::SharedContext::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index dff0abe0b..080c53b21 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -7,17 +7,17 @@ //! Provides a hook that runs both before and after request execution. use super::{after::AfterRequest, before::BeforeRequest}; -use crate::{RequestName, ServerError, context, server::Serve}; +use crate::{RequestName, ServerError, server::Serve}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct HookThenServeThenHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, - fns: PhantomData<(fn(Req), fn(Resp))>, + fns: PhantomData<(Req, Resp, ServerCtx)>, } -impl HookThenServeThenHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,7 +27,7 @@ impl HookThenServeThenHook { } } -impl Clone for HookThenServeThenHook { +impl Clone for HookThenServeThenHook { fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,18 +37,19 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; async fn serve( self, - ctx: &mut context::ServerContext, + ctx: &mut ServerCtx, req: Req, ) -> Result { let HookThenServeThenHook { diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 63b65d697..9a941f711 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::ServerContext; +use crate::context::{SharedContext}; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -39,8 +39,8 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> + for FakeChannel> { type Error = io::Error; @@ -50,7 +50,7 @@ impl Sink> fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { self.as_mut() .project() @@ -72,13 +72,14 @@ impl Sink> } impl Channel - for FakeChannel>, Response> + for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = SharedContext; fn config(&self) -> &Config { &self.config @@ -93,16 +94,16 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::ServerContext::new(context::SharedContext { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), - }), + }, id, message, }, @@ -119,8 +120,13 @@ impl FakeChannel>, Response { pub fn default() - -> FakeChannel>, Response> { + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); + + let mut x = anymap3::AnyMap::new(); + + x.entry::<&str>(); + FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 54fadf77d..05f1790d0 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -24,6 +24,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { + type Context = ServerContext; async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { match color { TestData::White => TestData::Black, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index b65a66104..ee44c58d8 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -14,6 +14,7 @@ use tarpc::{ transport::channel, }; use tokio::join; +use tarpc::context::{ServerContext}; #[tarpc_plugins::service] trait Service { @@ -25,11 +26,12 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { + type Context = ServerContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: &mut context::ServerContext, name: String) -> String { + async fn hey(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}.") } } @@ -67,7 +69,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: &mut context::ServerContext) { + type Context = ServerContext; + async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); } @@ -284,7 +287,8 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: &mut context::ServerContext) -> u32 { + type Context = ServerContext; + async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 } From 0045581c4f1ac30be739550e3ee8e4d67f54d81a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 13:25:41 +0100 Subject: [PATCH 13/26] remove servercontext entirely --- example-service/src/client.rs | 2 +- example-service/src/server.rs | 12 ++--- plugins/src/lib.rs | 8 +-- plugins/tests/service.rs | 22 ++++---- tarpc/examples/compression.rs | 9 ++-- tarpc/examples/custom_transport.rs | 9 ++-- tarpc/examples/pubsub.rs | 16 +++--- tarpc/examples/readme.rs | 6 +-- tarpc/examples/tls_over_tcp.rs | 9 ++-- tarpc/examples/tracing.rs | 18 +++---- tarpc/src/client/stub.rs | 9 ++-- tarpc/src/context.rs | 47 ++++------------- tarpc/src/lib.rs | 6 +-- tarpc/src/server.rs | 18 +++---- tarpc/src/server/incoming.rs | 2 +- .../src/server/limits/requests_per_channel.rs | 8 +-- tarpc/src/server/request_hook.rs | 10 ++-- tarpc/src/transport/channel.rs | 50 +++---------------- tarpc/tests/dataservice.rs | 9 ++-- tarpc/tests/service_functional.rs | 12 ++--- 20 files changed, 102 insertions(+), 180 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 71c9704ea..40402867f 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -9,7 +9,7 @@ use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; use tarpc::context::ClientContext; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 5e176dfa2..019a2d7b1 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,8 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::{ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_server}; +use tarpc::context::{SharedContext}; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, @@ -38,8 +37,8 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; @@ -66,14 +65,13 @@ async fn main() -> anyhow::Result<()> { listener // Ignore accept errors. .filter_map(|r| future::ready(r.ok())) - .map(map_transport_to_server) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().get_ref().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().get_ref().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 432b2f1c8..250ffff04 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,8 +375,10 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; -/// use futures_util::{TryStreamExt, sink::SinkExt}; +/// use tarpc::{client, transport, service, server::{self, Channel}}; +/// use futures_util::{TryStreamExt, sink::SinkExt};/// +/// +/// use tarpc::context::SharedContext; /// /// #[service] /// pub trait Calculator { @@ -402,7 +404,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// type Context = ServerContext; +/// type Context = SharedContext; /// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index ef49b9666..d8213f4d4 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::ServerContext; +use tarpc::context::SharedContext; #[test] fn att_service_trait() { @@ -13,21 +13,21 @@ fn att_service_trait() { } impl Foo for () { - type Context = ServerContext; + type Context = SharedContext; async fn two_part( self, - _: &mut context::ServerContext, + _: &mut context::SharedContext, s: String, i: i32, ) -> (String, i32) { (s, i) } - async fn bar(self, _: &mut context::ServerContext, s: String) -> String { + async fn bar(self, _: &mut Self::Context, s: String) -> String { s } - async fn baz(self, _: &mut context::ServerContext) {} + async fn baz(self, _: &mut Self::Context) {} } } @@ -44,21 +44,21 @@ fn raw_idents() { } impl r#trait for () { - type Context = ServerContext; + type Context = SharedContext; async fn r#await( self, - _: &mut context::ServerContext, + _: &mut Self::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: &mut context::ServerContext, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: &mut context::ServerContext) {} + async fn r#async(self, _: &mut Self::Context) {} } } @@ -72,8 +72,8 @@ fn service_with_cfg_rpc() { } impl Foo for () { - type Context = ServerContext; - async fn foo(self, _: &mut context::ServerContext) {} + type Context = SharedContext; + async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index c00ffc9f3..6a1440bd2 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,14 +9,14 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; -use tarpc::context::ServerContext; +use tarpc::context::SharedContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -110,8 +110,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -128,7 +128,6 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); let transport = add_compression(transport); - let transport = map_transport_to_server(transport); BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 415cb5442..92c723b4d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,12 +6,12 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -23,8 +23,8 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = ServerContext; - async fn ping(self, _: &mut ServerContext) {} + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -42,7 +42,6 @@ async fn main() -> anyhow::Result<()> { let (conn, _addr) = listener.accept().await.unwrap(); let framed = codec_builder.new_framed(conn); let transport = transport::new(framed, Bincode::default()); - let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6755e49ca..fbe19078a 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -48,8 +48,8 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -82,12 +82,12 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - type Context = ServerContext; - async fn topics(self, _: &mut context::ServerContext) -> Vec { + type Context = SharedContext; + async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: &mut context::ServerContext, topic: String, message: String) { + async fn receive(self, _: &mut Self::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -107,7 +107,6 @@ impl Subscriber { ) -> anyhow::Result { let publisher = tcp::connect(publisher_addr, Json::default).await?; let local_addr = publisher.local_addr()?; - let publisher = map_transport_to_server(publisher); let mut handler = server::BaseChannel::with_defaults(publisher).requests(); let subscriber = Subscriber { local_addr, topics }; // The first request is for the topics being subscribed to. @@ -168,7 +167,6 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); - let publisher = map_transport_to_server(publisher); server::BaseChannel::with_defaults(publisher) .execute(self.serve()) @@ -272,8 +270,8 @@ impl Publisher { } impl publisher::Publisher for Publisher { - type Context = ServerContext; - async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { + type Context = SharedContext; + async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c7e8de00b..db93d2e74 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -25,8 +25,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = ServerContext; - async fn hello(self, _: &mut context::ServerContext, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } } diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 4ed3298bb..0e00cdca8 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,12 +10,12 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; +use tarpc::context::{ClientContext, SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -33,8 +33,8 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = ServerContext; - async fn ping(self, _: &mut ServerContext) -> String { + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } } @@ -116,7 +116,6 @@ async fn main() -> anyhow::Result<()> { let framed = codec_builder.new_framed(tls_stream); let transport = transport::new(framed, Bincode::default()); - let transport = map_transport_to_server(transport); let fut = BaseChannel::with_defaults(transport) .execute(Service.serve()) diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index f747c9d75..b69e0c1a0 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,8 +19,8 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -58,8 +58,8 @@ pub mod double { struct AddServer; impl AddService for AddServer { - type Context = ServerContext; - async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { + type Context = SharedContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } @@ -73,8 +73,8 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - type Context = ServerContext; - async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { + type Context = SharedContext; + async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client .add(&mut context::ClientContext::current(), x, x) .await @@ -180,7 +180,6 @@ async fn main() -> anyhow::Result<()> { .serving(AddServer.serve()); let add_server = add_listener1 .chain(add_listener2) - .map(map_transport_to_server) .map(BaseChannel::with_defaults); tokio::spawn(spawn_incoming(add_server.execute(server))); @@ -191,9 +190,8 @@ async fn main() -> anyhow::Result<()> { let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? - .filter_map(|r| future::ready(r.ok())) - .map(map_transport_to_server); - let addr = double_listener.get_ref().get_ref().local_addr(); + .filter_map(|r| future::ready(r.ok())); + let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 6fa159dd7..9989f0577 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -3,10 +3,9 @@ use crate::{ RequestName, client::{Channel, RpcError}, - context, server::Serve, }; -use crate::context::{ClientContext, ServerContext}; +use crate::context::{ClientContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -54,7 +53,7 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; @@ -64,7 +63,7 @@ where ctx: &mut ClientContext, req: Self::Req, ) -> Result { - let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); + let mut server_ctx = ctx.shared_context.clone(); let res = self .clone() @@ -72,7 +71,7 @@ where .await .map_err(RpcError::Server); - ctx.shared_context = server_ctx.shared_context; + ctx.shared_context = server_ctx; res } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 798044c93..5cc9389f1 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -56,43 +56,6 @@ impl ExtractContext for T where T: Clone { } } -/// Request context that carries request-scoped server side information like deadlines and trace information -/// as well as any server side extensions defined by the transport, hooks or service implementations. -/// It is build from the shared context sent from client to server. -/// -/// The context should not be stored directly in a server implementation, because the context will -/// be different for each request in scope. -#[derive(Debug)] -pub struct ServerContext { - /// Shared context sent from client to server which contains information used by both sides. - pub shared_context: SharedContext, -} - -impl ServerContext { - /// Creates a new ServerContext from the given SharedContext with no extensions. - pub fn new(shared_context: SharedContext) -> Self { - Self { shared_context } - } - - /// Creates a new ServerContext for the current shared context with no extensions. - pub fn current() -> Self { - Self::new(SharedContext::current()) - } -} - -impl Deref for ServerContext { - type Target = SharedContext; - - fn deref(&self) -> &Self::Target { - &self.shared_context - } -} -impl DerefMut for ServerContext { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.shared_context - } -} - /// Request context that carries request-scoped client side information like deadlines and trace information /// as well as any server side extensions defined by the transport, hooks and stubs. /// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. @@ -103,12 +66,20 @@ impl DerefMut for ServerContext { pub struct ClientContext { /// Shared context sent from client to server which contains information used by both sides. pub shared_context: SharedContext, + + /// Client side extensions that are not seen by the server + /// XXX, YYY, and ZZZ can use this to store per-request data, and communicate with eachother. + /// Note that this is NOT sent to the server, and they will always see an empty map here. + pub client_context: anymap3::Map, } impl ClientContext { /// Creates a new ServerContext from the given SharedContext with no extensions. pub fn new(shared_context: SharedContext) -> Self { - Self { shared_context } + Self { + shared_context, + client_context: anymap3::Map::new(), + } } /// Creates a new ServerContext for the current shared context with no extensions. diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 0578a392f..e0869d9f6 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,7 +124,7 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! type Context = context::ServerContext; +//! type Context = context::SharedContext; //! // Each defined rpc generates an async fn that serves the RPC //! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") @@ -145,7 +145,7 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{ClientContext, ServerContext, SharedContext}, +//! # context::{ClientContext, SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -161,7 +161,7 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # type Context = ServerContext; +//! # type Context = SharedContext; //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index fe7440f7e..e6c395836 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,7 +363,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext, ServerContext}, + /// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -407,7 +407,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext, ServerContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -767,7 +767,7 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{ClientContext, SharedContext, ServerContext}; + /// use tarpc::context::{ClientContext, SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -872,7 +872,7 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext, ServerContext}, + /// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -1106,7 +1106,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::ServerContext::current(), 7).await, + serve.serve(&mut context::SharedContext::current(), 7).await, Ok(7) ); } @@ -1178,10 +1178,10 @@ mod tests { } } - let serve = serve(move |_: &mut context::ServerContext, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::SharedContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::ServerContext::current(), 7) + .serve(&mut context::SharedContext::current(), 7) .await?; Ok(()) } @@ -1189,11 +1189,11 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::SharedContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::ServerContext::current(), 7) + .serve(&mut context::SharedContext::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 56f393b84..2baa27c89 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{ClientContext, ServerContext, SharedContext}, +/// context::{ClientContext, SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 383abb9c8..deb723bda 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -190,7 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; - use crate::context::{ServerContext, SharedContext}; + use crate::context::{SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -271,7 +271,7 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> + -> PendingSink>, Response> { PendingSink { ghost: PhantomData } } @@ -298,12 +298,12 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = ServerContext; + type ServerCtx = SharedContext; fn config(&self) -> &Config { unimplemented!() } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 338059f7d..4f3d60377 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::ServerContext, req: &i32| { + /// .before(|_ctx: &mut context::SharedContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::ServerContext, resp: &mut Result| { + /// .after(|_ctx: &mut context::SharedContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::ServerContext::current(); + /// let mut context = context::SharedContext::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 7615d8fe1..476f60738 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,13 +6,14 @@ //! Transports backed by in-memory channels. -use crate::context::{ClientContext, ServerContext, SharedContext}; +use crate::context::{ClientContext, SharedContext}; use crate::{ClientMessage, Response, Transport}; use futures::future::{Ready}; use futures::sink::With; use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; use pin_project::pin_project; use std::{error::Error, future, pin::Pin}; +use std::convert::identity; use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. @@ -87,13 +88,13 @@ where /// [`Sink`]. pub fn unbounded_for_client_server_context() -> ( impl Transport, Response>, - impl Transport, ClientMessage>, + impl Transport, ClientMessage>, ) { unbounded_mapped( map_req_client_context_to_shared, - map_req_shared_context_to_server, + identity, map_resp_shared_context_to_client, - map_resp_server_context_to_shared, + identity, ) } @@ -104,21 +105,7 @@ fn map_req_client_context_to_shared( msg.map_context(|ctx| ctx.shared_context) } -/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. -fn map_req_shared_context_to_server( - msg: ClientMessage, -) -> ClientMessage { - msg.map_context(ServerContext::new) -} - -/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -fn map_resp_server_context_to_shared( - resp: Response, -) -> Response { - resp.map_context(|ctx| ctx.shared_context) -} - -/// Convenience function to map a ClientMessage with SharedContext to one with ServerContext. +/// Convenience function to map a ClientMessage with SharedContext to one with ClientContext. fn map_resp_shared_context_to_client( msg: Response, ) -> Response { @@ -148,31 +135,6 @@ where t.with(f).map_ok(map_resp_shared_context_to_client) } -/// TODO: document -/// -/// Yuck, but impl trait will loose our ability to do t.as_ref() -pub fn map_transport_to_server( - t: T, -) -> futures::stream::MapOk< - With< - T, - Response, - Response, - Ready, E>>, - fn(Response) -> Ready, E>>, - >, - fn(ClientMessage) -> ClientMessage, -> -where - T: Transport, ClientMessage>, - E: From -{ - let f: fn(Response) -> Ready, E>> = |resp| futures::future::ok(map_resp_server_context_to_shared(resp)); - - t.with(f) - .map_ok(map_req_shared_context_to_server) -} - /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 05f1790d0..1a1b5207a 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; -use tarpc::context::{ClientContext, ServerContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::context::{ClientContext, SharedContext}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ClientMessage, serde_transport}; use tarpc::{ client, context, @@ -24,8 +24,8 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - type Context = ServerContext; - async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { + type Context = SharedContext; + async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -46,7 +46,6 @@ async fn test_call() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(ColorServer.serve()) .map(|channel| channel.for_each(spawn)) diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index ee44c58d8..4692b17cd 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::transport::channel::{map_transport_to_client, map_transport_to_server}; +use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client::{self}, @@ -14,7 +14,7 @@ use tarpc::{ transport::channel, }; use tokio::join; -use tarpc::context::{ServerContext}; +use tarpc::context::SharedContext; #[tarpc_plugins::service] trait Service { @@ -26,7 +26,7 @@ trait Service { struct Server; impl Service for Server { - type Context = ServerContext; + type Context = SharedContext; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -69,7 +69,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - type Context = ServerContext; + type Context = SharedContext; async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); @@ -118,7 +118,6 @@ async fn serde_tcp() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -157,7 +156,6 @@ async fn serde_uds() -> anyhow::Result<()> { transport .take(1) .filter_map(|r| async { r.ok() }) - .map(map_transport_to_server) .map(BaseChannel::with_defaults) .execute(Server.serve()) .map(|channel| channel.for_each(spawn)) @@ -287,7 +285,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type Context = ServerContext; + type Context = SharedContext; async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 From 7989bc084b78e41c3a59cc3eef2135d8ee22c63d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:09:34 +0100 Subject: [PATCH 14/26] make clientContext generic as well --- example-service/src/client.rs | 9 +- plugins/src/lib.rs | 43 +++--- tarpc/examples/compression.rs | 4 +- tarpc/examples/custom_transport.rs | 6 +- tarpc/examples/pubsub.rs | 46 +++--- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 6 +- tarpc/examples/tracing.rs | 32 +++-- tarpc/src/client.rs | 131 ++++++++++-------- tarpc/src/client/stub.rs | 21 +-- tarpc/src/client/stub/load_balance.rs | 14 +- tarpc/src/client/stub/mock.rs | 4 +- tarpc/src/client/stub/retry.rs | 4 +- tarpc/src/context.rs | 68 --------- tarpc/src/lib.rs | 9 +- tarpc/src/server.rs | 24 ++-- tarpc/src/server/incoming.rs | 6 +- tarpc/src/transport/channel.rs | 109 +-------------- .../compile_fail/must_use_request_dispatch.rs | 4 +- .../must_use_request_dispatch.stderr | 6 +- tarpc/tests/dataservice.rs | 9 +- tarpc/tests/service_functional.rs | 44 +++--- 22 files changed, 234 insertions(+), 371 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 40402867f..b8ff22c97 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -8,8 +8,7 @@ use clap::Parser; use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::ClientContext; -use tarpc::transport::channel::{map_transport_to_client}; +use tarpc::context::{SharedContext}; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -32,15 +31,15 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = map_transport_to_client(transport.await?); + let transport = transport.await?; // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. let client = WorldClient::new(client::Config::default(), transport).spawn(); let hello = async move { - let mut context = ClientContext::current(); - let mut context2 = ClientContext::current(); + let mut context = SharedContext::current(); + let mut context2 = SharedContext::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 250ffff04..cf107d0ad 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -395,7 +395,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// let resp = CalculatorResponse::Add(12); /// /// // This could be any transport. -/// let (client_side, server_side) = transport::channel::unbounded_for_client_server_context(); +/// let (client_side, server_side) = transport::channel::unbounded(); /// /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); @@ -583,11 +583,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } - impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + impl #client_stub_ident for S + where S: ::tarpc::client::stub::Stub { } } @@ -717,12 +717,19 @@ impl ServiceGenerator<'_> { quote! { #[allow(unused)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](::core::future::Future). #vis struct #client_ident< - Stub = ::tarpc::client::Channel<#request_ident, #response_ident> - >(Stub); + ClientCtx, + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx> + >(Stub, ::std::marker::PhantomData); + + impl ::std::clone::Clone for #client_ident { + fn clone(&self) -> Self { + Self(self.0.clone(), ::std::marker::PhantomData) + } + } } } @@ -736,32 +743,33 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident { + impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<::tarpc::context::ClientContext, #request_ident>, ::tarpc::Response<::tarpc::context::ClientContext, #response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { - client: #client_ident(new_client.client), + client: #client_ident(new_client.client, ::std::marker::PhantomData), dispatch: new_client.dispatch, } } } - impl ::core::convert::From for #client_ident + impl ::core::convert::From for #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { /// Returns a new client stub that sends requests over the given transport. fn from(stub: Stub) -> Self { - #client_ident(stub) + #client_ident::(stub, ::std::marker::PhantomData) } } @@ -784,15 +792,16 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident + impl #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ServerCtx, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 6a1440bd2..f201521ad 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,7 +9,6 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ client, context, serde_transport::tcp, @@ -136,13 +135,12 @@ async fn main() -> anyhow::Result<()> { let transport = tcp::connect(addr, Bincode::default).await?; let transport = add_compression(transport); - let transport = map_transport_to_client(transport); let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", client - .hello(&mut context::ClientContext::current(), "friend".into()) + .hello(&mut context::SharedContext::current(), "friend".into()) .await? ); Ok(()) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 92c723b4d..c9eb871ea 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,12 +6,11 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::{UnixListener, UnixStream}; #[tarpc::service] @@ -52,10 +51,9 @@ async fn main() -> anyhow::Result<()> { let conn = UnixStream::connect(bind_addr).await?; let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); - let transport = map_transport_to_client(transport); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut ClientContext::current()) + .ping(&mut SharedContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index fbe19078a..07a93becf 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -47,9 +47,11 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use std::ops::Shl; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use subscriber::Subscriber as _; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -135,10 +137,19 @@ struct Subscription { topics: Vec, } -#[derive(Clone, Debug)] -struct Publisher { +#[derive(Debug)] +struct Publisher { clients: Arc>>, - subscriptions: Arc>>>, + subscriptions: Arc>>>>, +} + +impl Clone for Publisher { + fn clone(&self) -> Self { + Publisher { + clients: self.clients.clone(), + subscriptions: self.subscriptions.clone(), + } + } } struct PublisherAddrs { @@ -150,7 +161,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher { +impl Publisher where ClientCtx: ExtractContext + From + Serialize + DeserializeOwned + Send + Sync + 'static { // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -187,7 +198,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let conn = map_transport_to_client(conn); let tarpc::client::NewClient { client: subscriber, dispatch, @@ -211,11 +221,11 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics if let Ok(topics) = subscriber - .topics(&mut context::ClientContext::current()) + .topics(&mut ClientCtx::from(context::SharedContext::current())) .await { self.clients.lock().unwrap().insert( @@ -269,8 +279,8 @@ impl Publisher { } } -impl publisher::Publisher for Publisher { - type Context = SharedContext; +impl publisher::Publisher for Publisher where ClientCtx: ExtractContext + From + Send + Sync + 'static { + type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { @@ -283,7 +293,7 @@ impl publisher::Publisher for Publisher { publications.push(async { client .receive( - &mut context::ClientContext::current(), + &mut ClientCtx::from(context::SharedContext::current()), topic.clone(), message.clone(), ) @@ -333,7 +343,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -354,13 +364,13 @@ async fn main() -> anyhow::Result<()> { let publisher = publisher::PublisherClient::new( client::Config::default(), - map_transport_to_client(tcp::connect(addrs.publisher, Json::default).await?), + tcp::connect(addrs.publisher, Json::default).await?, ) .spawn(); publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "calculus".into(), "sqrt(2)".into(), ) @@ -368,7 +378,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to all".into(), ) @@ -376,7 +386,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "history".into(), "napoleon".to_string(), ) @@ -386,7 +396,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut ClientContext::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index db93d2e74..359b4af8b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -38,7 +38,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let (client_transport, server_transport) = - transport::channel::unbounded_for_client_server_context(); + transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); @@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> { // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. let hello = client - .hello(&mut context::ClientContext::current(), "Stim".to_string()) + .hello(&mut context::SharedContext::current(), "Stim".to_string()) .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 0e00cdca8..c203bf0b8 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,12 +10,11 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{ClientContext, SharedContext}; +use tarpc::context::{SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; -use tarpc::transport::channel::{map_transport_to_client}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -145,10 +144,9 @@ async fn main() -> anyhow::Result<()> { let stream = connector.connect(domain, stream).await?; let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); - let transport = map_transport_to_client(transport); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut ClientContext::current()) + .ping(&mut SharedContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index b69e0c1a0..525a16a47 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,8 +19,8 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; +use std::marker::PhantomData; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -65,18 +65,20 @@ impl AddService for AddServer { } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, + ghost: PhantomData } -impl DoubleService for DoubleServer +impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static { type Context = SharedContext; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(&mut context::ClientContext::current(), x, x) + .add(&mut ClientCtx::from(context::SharedContext::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -127,18 +129,19 @@ where Ok((listener, addr)) } -fn make_stub( - backends: [impl Transport>, Response> +fn make_stub( + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp>>, + load_balance::RoundRobin, Resp, ClientCtx>>, > where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static { let stub = load_balance::RoundRobin::new( backends @@ -184,8 +187,8 @@ async fn main() -> anyhow::Result<()> { tokio::spawn(spawn_incoming(add_server.execute(server))); let add_client = add::AddClient::from(make_stub([ - map_transport_to_client(tarpc::serde_transport::tcp::connect(addr1, Json::default).await?), - map_transport_to_client(tarpc::serde_transport::tcp::connect(addr2, Json::default).await?), + tarpc::serde_transport::tcp::connect(addr1, Json::default).await?, + tarpc::serde_transport::tcp::connect(addr2, Json::default).await?, ])); let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) @@ -193,11 +196,10 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer { add_client }.serve(); + let server = DoubleServer::<_, SharedContext> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; - let to_double_server = map_transport_to_client(to_double_server); let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); @@ -205,7 +207,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::ClientContext::current(), 1) + .double(&mut context::SharedContext::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index b6763a9b2..40ba7e461 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::context::{ClientContext, ExtractContext, SharedContext}; +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -30,6 +30,7 @@ use std::{ }, time::SystemTime, }; +use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -96,27 +97,32 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { +pub struct Channel { to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, + + ///TODO: Document + ghost: PhantomData } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), + ghost: PhantomData } } } -impl Channel +impl Channel where Req: RequestName, + ClientCtx: ExtractContext { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -129,9 +135,9 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call>( + pub async fn call( &self, - ctx: &mut Ctx, + ctx: &mut ClientCtx, request: Req, ) -> Result { let span = Span::current(); @@ -245,12 +251,12 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -260,6 +266,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }, dispatch: RequestDispatch { config, @@ -268,6 +275,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, + ghost: PhantomData }, } } @@ -277,7 +285,7 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, @@ -294,11 +302,14 @@ pub struct RequestDispatch { /// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type /// determined within the poll function. terminal_error: Option>, + + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, @@ -321,7 +332,7 @@ where fn start_send( self: &mut Pin<&mut Self>, - message: ClientMessage, + message: ClientMessage, ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -531,7 +542,7 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ClientContext::new(ctx), + context: ctx.into(), }); self.in_flight_requests() @@ -581,10 +592,10 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context.shared_context, m)), + response.message.map_err(RpcError::Server).map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -659,9 +670,10 @@ where } } -impl Future for RequestDispatch +impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { type Output = Result<(), ChannelError>; @@ -704,7 +716,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::{ClientContext, SharedContext}; + use crate::context::{SharedContext}; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, @@ -735,7 +747,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = ClientContext::current(); + let context = SharedContext::current(); dispatch .in_flight_requests @@ -750,7 +762,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: ClientContext::current(), + context: SharedContext::current(), message: Ok("Resp".into()), }) .await @@ -796,7 +808,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -826,7 +838,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: ClientContext::current(), + context: SharedContext::current(), message: Ok("hello".into()), }, ) @@ -837,7 +849,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -853,7 +865,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -874,7 +886,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -890,7 +902,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -907,11 +919,11 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send let resp = channel - .call(&mut ClientContext::current(), "hi".to_string()) + .call(&mut SharedContext::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -919,7 +931,7 @@ mod tests { #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -942,7 +954,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -959,7 +971,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -969,7 +981,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -979,7 +991,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -988,34 +1000,36 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin>>>, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData); + struct AlwaysErrorTransport(TransportError, PhantomData<( I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1032,7 +1046,7 @@ mod tests { } } - impl Sink for AlwaysErrorTransport { + impl Sink for AlwaysErrorTransport { type Error = TransportError; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { match self.0 { @@ -1064,8 +1078,8 @@ mod tests { } } - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + impl Stream for AlwaysErrorTransport { + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1075,21 +1089,22 @@ mod tests { } } - fn set_up() -> ( + fn set_up() -> ( Pin< Box< RequestDispatch< String, String, + ClientCtx, UnboundedChannel< - Response, - ClientMessage, + Response, + ClientMessage, >, >, >, >, - Channel, - UnboundedChannel, Response>, + Channel, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1097,26 +1112,28 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; (Box::pin(dispatch), channel, server_channel) } - async fn send_request<'a>( - channel: &'a mut Channel, + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, request: &str, response_completion: oneshot::Sender>, response: &'a mut oneshot::Receiver>, @@ -1140,8 +1157,8 @@ mod tests { response_guard } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, + async fn reserve_for_send<'a, ClientCtx>( + channel: &'a mut Channel, response_completion: oneshot::Sender>, response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { @@ -1166,12 +1183,12 @@ mod tests { } } - async fn send_response( + async fn send_response( channel: &mut UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - response: Response, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 9989f0577..992f6d611 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -5,7 +5,7 @@ use crate::{ client::{Channel, RpcError}, server::Serve, }; -use crate::context::{ClientContext, SharedContext}; +use crate::context::{ExtractContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -24,27 +24,28 @@ pub trait Stub { type Resp; ///TODO: document - type ServerCtx; + type ClientCtx; /// Calls a remote service. async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, + ClientCtx: ExtractContext { type Req = Req; type Resp = Resp; - type ServerCtx = ClientContext; + type ClientCtx = ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Req, ) -> Result { Self::call(self, ctx, request).await @@ -57,13 +58,13 @@ where { type Req = S::Req; type Resp = S::Resp; - type ServerCtx = ClientContext; + type ClientCtx = SharedContext; async fn call( &self, - ctx: &mut ClientContext, + ctx: &mut Self::ClientCtx, req: Self::Req, ) -> Result { - let mut server_ctx = ctx.shared_context.clone(); + let mut server_ctx = ctx.clone(); let res = self .clone() @@ -71,7 +72,7 @@ where .await .map_err(RpcError::Server); - ctx.shared_context = server_ctx; + *ctx = server_ctx; res } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 5b319c6c8..60efafc91 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -16,11 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -115,11 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -201,17 +201,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::ClientContext::current(), 'a') + .call(&mut context::SharedContext::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::ClientContext::current(), 'b') + .call(&mut context::SharedContext::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::ClientContext::current(), 'c') + .call(&mut context::SharedContext::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 9a22d101e..577ef5362 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -31,11 +31,11 @@ where { type Req = Req; type Resp = Resp; - type ServerCtx = ServerCtx; + type ClientCtx = ServerCtx; async fn call( &self, - _: &mut Self::ServerCtx, + _: &mut Self::ClientCtx, request: Self::Req, ) -> Result { self.responses diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 2cf950aed..5499f60e4 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -14,11 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; - type ServerCtx = Stub::ServerCtx; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: &mut Self::ServerCtx, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 5cc9389f1..5ca5b8256 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -56,74 +56,6 @@ impl ExtractContext for T where T: Clone { } } -/// Request context that carries request-scoped client side information like deadlines and trace information -/// as well as any server side extensions defined by the transport, hooks and stubs. -/// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. -/// -/// The context should not be stored directly in a stub implementation, because the context will -/// be different for each request in scope. -#[derive(Debug)] -pub struct ClientContext { - /// Shared context sent from client to server which contains information used by both sides. - pub shared_context: SharedContext, - - /// Client side extensions that are not seen by the server - /// XXX, YYY, and ZZZ can use this to store per-request data, and communicate with eachother. - /// Note that this is NOT sent to the server, and they will always see an empty map here. - pub client_context: anymap3::Map, -} - -impl ClientContext { - /// Creates a new ServerContext from the given SharedContext with no extensions. - pub fn new(shared_context: SharedContext) -> Self { - Self { - shared_context, - client_context: anymap3::Map::new(), - } - } - - /// Creates a new ServerContext for the current shared context with no extensions. - pub fn current() -> Self { - Self::new(SharedContext::current()) - } -} - -impl ExtractContext for ClientContext { - fn extract(&self) -> SharedContext { - self.shared_context.clone() - } - - fn update(&mut self, value: SharedContext) { - self.shared_context = value - } -} - -impl ExtractContext for ServerContext { - fn extract(&self) -> SharedContext { - self.shared_context.clone() - } - - fn update(&mut self, value: SharedContext) { - self.shared_context = value - } -} - - - -impl Deref for ClientContext { - type Target = SharedContext; - - fn deref(&self) -> &Self::Target { - &self.shared_context - } -} - -impl DerefMut for ClientContext { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.shared_context - } -} - #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e0869d9f6..fc79e3056 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -145,7 +145,7 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{ClientContext, SharedContext}, +//! # context::{SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -172,7 +172,8 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = channel::unbounded_for_client_server_context(); +//! use futures::future::Shared; +//! let (client_transport, server_transport) = channel::unbounded(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -183,12 +184,12 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::ClientContext::current(); +//! let mut context = context::SharedContext::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index e6c395836..7d345a203 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -363,7 +363,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext}, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -372,7 +372,7 @@ where /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -383,7 +383,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -407,7 +407,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{ClientContext, SharedContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{SharedContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -416,7 +416,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( @@ -424,7 +424,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -767,7 +767,7 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{ClientContext, SharedContext}; + /// use tarpc::context::{SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -775,7 +775,7 @@ where /// # #[cfg(feature = "tokio1")] /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( @@ -783,7 +783,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -872,7 +872,7 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{ClientContext, SharedContext}, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -881,7 +881,7 @@ impl InFlightRequest { /// /// #[tokio::main] /// async fn main() { - /// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + /// let (tx, rx) = transport::channel::unbounded(); /// let server = BaseChannel::new(server::Config::default(), rx); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); @@ -892,7 +892,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::ClientContext::current(); + /// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 2baa27c89..6a71124b1 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,7 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{ClientContext, SharedContext}, +/// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -59,7 +59,7 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded_for_client_server_context(); +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// @@ -67,7 +67,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::ClientContext::current(); +/// let mut context = context::SharedContext::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 476f60738..de9a8afdc 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -6,14 +6,9 @@ //! Transports backed by in-memory channels. -use crate::context::{ClientContext, SharedContext}; -use crate::{ClientMessage, Response, Transport}; -use futures::future::{Ready}; -use futures::sink::With; -use futures::{Sink, SinkExt, Stream, TryStreamExt, task::*}; +use futures::{Sink, Stream, task::*}; use pin_project::pin_project; -use std::{error::Error, future, pin::Pin}; -use std::convert::identity; +use std::{error::Error, pin::Pin}; use tokio::sync::mpsc; /// Errors that occur in the sending or receiving of messages over a channel. @@ -44,97 +39,6 @@ pub fn unbounded() -> ( ) } -/// Returns two mapped unbounded channel peers. Each [`Stream`] yields items sent through the other's -/// [`Sink`]. -pub fn unbounded_mapped< - SerializedSinkItem, - SerializedItem, - ClientSinkItem, - ServerSinkItem, - ClientItem, - ServerItem, - F, - G, - H, - I, ->( - mut f: F, - mut g: G, - mut h: H, - mut i: I, -) -> ( - impl Transport, - impl Transport, -) -where - F: FnMut(ClientSinkItem) -> SerializedSinkItem, - G: FnMut(SerializedSinkItem) -> ServerSinkItem, - H: FnMut(SerializedItem) -> ClientItem, - I: FnMut(ServerItem) -> SerializedItem, -{ - let (client, server) = unbounded(); - - let client = client - .with(move |msg: ClientSinkItem| future::ready(Ok(f(msg)))) - .map_ok(move |msg: SerializedItem| h(msg)); - let server = server - .map_ok(move |msg: SerializedSinkItem| g(msg)) - .with(move |msg: ServerItem| future::ready(Ok(i(msg)))); - - (client, server) -} - -/// Convenience functino to return two mapped unbounded channel peers for a basechannel and a client implementation. Each [`Stream`] yields items sent through the other's -/// [`Sink`]. -pub fn unbounded_for_client_server_context() -> ( - impl Transport, Response>, - impl Transport, ClientMessage>, -) { - unbounded_mapped( - map_req_client_context_to_shared, - identity, - map_resp_shared_context_to_client, - identity, - ) -} - -/// Convenience function to map a ClientMessage with ClientContext to one with SharedContext. -fn map_req_client_context_to_shared( - msg: ClientMessage, -) -> ClientMessage { - msg.map_context(|ctx| ctx.shared_context) -} - -/// Convenience function to map a ClientMessage with SharedContext to one with ClientContext. -fn map_resp_shared_context_to_client( - msg: Response, -) -> Response { - msg.map_context(ClientContext::new) -} - -/// TODO: document -/// Yuck, but impl trait will loose our ability to do t.as_ref() -pub fn map_transport_to_client( - t: T, -) -> futures::stream::MapOk< - With< - T, - ClientMessage, - ClientMessage, - Ready, E>>, - fn(ClientMessage) -> Ready, E>>, - >, - fn(Response) -> Response, -> -where - T: Transport, Response>, - E: From -{ - let f: fn(ClientMessage) -> Ready, E>> = |resp| futures::future::ok(map_req_client_context_to_shared(resp)); - - t.with(f).map_ok(map_resp_shared_context_to_client) -} - /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). #[derive(Debug)] @@ -271,6 +175,7 @@ mod tests { use futures::{prelude::*, stream}; use std::io; use tracing::trace; + use crate::context::SharedContext; #[test] fn ensure_is_transport() { @@ -284,12 +189,12 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let (client_channel, server_channel) = - transport::channel::unbounded_for_client_server_context(); + transport::channel::unbounded(); tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| { + .execute(serve(|_ctx: &mut SharedContext, request: String| { async move { request.parse::().map_err(|_| { ServerError::new( @@ -308,10 +213,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::ClientContext::current(), "123".into()) + .call(&mut context::SharedContext::current(), "123".into()) .await; let response2 = client - .call(&mut context::ClientContext::current(), "abc".into()) + .call(&mut context::SharedContext::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..a5238fe8b 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; - +use tarpc::context::SharedContext; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8e8..e0ec77ff3 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 1a1b5207a..6bcd255c4 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,7 +1,6 @@ use futures::prelude::*; -use tarpc::context::{ClientContext, SharedContext}; -use tarpc::transport::channel::{map_transport_to_client}; -use tarpc::{ClientMessage, serde_transport}; +use tarpc::context::{SharedContext}; +use tarpc::{serde_transport}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, @@ -53,12 +52,10 @@ async fn test_call() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = map_transport_to_client(transport); - let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::ClientContext::current(), TestData::White) + .get_opposite_color(&mut context::SharedContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 4692b17cd..7d1f96e18 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,6 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::transport::channel::{map_transport_to_client}; use tarpc::{ ClientMessage, client::{self}, @@ -38,7 +37,7 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = channel::unbounded(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); @@ -51,7 +50,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::ClientContext::current(), 1) + .call(&mut context::SharedContext::current(), 1) .await .unwrap(), 2 @@ -79,14 +78,14 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::ClientContext::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -125,17 +124,16 @@ async fn serde_tcp() -> anyhow::Result<()> { ); let transport = serde_transport::tcp::connect(addr, Json::default).await?; - let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); assert_matches!( client - .add(&mut context::ClientContext::current(), 1, 2) + .add(&mut context::SharedContext::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, + client.hey(&mut context::SharedContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -163,16 +161,15 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let transport = map_transport_to_client(transport); let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::ClientContext::current(), 1, 2) + .add(&mut context::SharedContext::current(), 1, 2) .await; let res2 = client - .hey(&mut context::ClientContext::current(), "Tim".to_string()) + .hey(&mut context::SharedContext::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -185,7 +182,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -197,7 +194,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::ClientContext::current(); + let mut context = context::SharedContext::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -215,7 +212,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -227,9 +224,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::ClientContext::current(); - let mut context2 = context::ClientContext::current(); - let mut context3 = context::ClientContext::current(); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); + let mut context3 = context::SharedContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -252,8 +249,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded_for_client_server_context(); - + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -262,8 +258,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::ClientContext::current(); - let mut context2 = context::ClientContext::current(); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -292,7 +288,7 @@ async fn counter() -> anyhow::Result<()> { } } - let (tx, rx) = channel::unbounded_for_client_server_context(); + let (tx, rx) = channel::unbounded(); tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -305,11 +301,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::ClientContext::current()).await, + client.count(&mut context::SharedContext::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::ClientContext::current()).await, + client.count(&mut context::SharedContext::current()).await, Ok(2) ); From 8bd243a19d2f5a207e45539d7835b49a7ed83911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:15:32 +0100 Subject: [PATCH 15/26] fix merge conflict --- tarpc/src/context.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 5ca5b8256..e89a7f044 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -10,7 +10,6 @@ use crate::trace::{self, TraceId}; use opentelemetry::trace::TraceContextExt; use static_assertions::assert_impl_all; -use std::ops::{Deref, DerefMut}; use std::{ convert::TryFrom, time::{Duration, Instant}, From cf3fa5371e0f4ed0fe978e46cf4c7ed3b23fd9d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:16:14 +0100 Subject: [PATCH 16/26] run cargo fmt --- example-service/src/client.rs | 2 +- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 27 +++---- tarpc/examples/compression.rs | 2 +- tarpc/examples/custom_transport.rs | 2 +- tarpc/examples/pubsub.rs | 27 +++++-- tarpc/examples/readme.rs | 5 +- tarpc/examples/tls_over_tcp.rs | 2 +- tarpc/examples/tracing.rs | 20 +++-- tarpc/src/client.rs | 73 ++++++++++--------- tarpc/src/client/in_flight_requests.rs | 15 +++- tarpc/src/client/stub.rs | 10 +-- tarpc/src/client/stub/load_balance.rs | 8 +- tarpc/src/client/stub/mock.rs | 12 +-- tarpc/src/context.rs | 7 +- tarpc/src/server.rs | 56 +++++++------- .../src/server/limits/requests_per_channel.rs | 14 ++-- tarpc/src/server/request_hook/after.rs | 12 +-- tarpc/src/server/request_hook/before.rs | 42 ++++------- .../server/request_hook/before_and_after.rs | 13 ++-- tarpc/src/server/testing.rs | 9 ++- tarpc/src/transport/channel.rs | 5 +- tarpc/tests/dataservice.rs | 4 +- tarpc/tests/service_functional.rs | 2 +- 24 files changed, 186 insertions(+), 185 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index b8ff22c97..627e67504 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -8,7 +8,7 @@ use clap::Parser; use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 019a2d7b1..9c9160e17 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,7 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index cf107d0ad..1a5b7e6db 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -550,22 +550,19 @@ impl ServiceGenerator<'_> { .. } = self; - let rpc_fns = rpcs - .iter() - .zip(return_types.iter()) - .map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; - } + let rpc_fns = rpcs.iter().zip(return_types.iter()).map( + |( + RpcMethod { + attrs, ident, args, .. }, - ); + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, + ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index f201521ad..1a3a7d566 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,13 +9,13 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; +use tarpc::context::SharedContext; use tarpc::{ client, context, serde_transport::tcp, server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; -use tarpc::context::SharedContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index c9eb871ea..859bed0ed 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,7 +6,7 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 07a93becf..5e915e1b0 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -40,6 +40,9 @@ use futures::{ }; use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::ops::Shl; use std::{ collections::HashMap, error::Error, @@ -47,9 +50,6 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; -use std::ops::Shl; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; use subscriber::Subscriber as _; use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ @@ -140,7 +140,8 @@ struct Subscription { #[derive(Debug)] struct Publisher { clients: Arc>>, - subscriptions: Arc>>>>, + subscriptions: + Arc>>>>, } impl Clone for Publisher { @@ -161,7 +162,17 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher where ClientCtx: ExtractContext + From + Serialize + DeserializeOwned + Send + Sync + 'static { // TODO: Remove serde bounds here +impl Publisher +where + ClientCtx: ExtractContext + + From + + Serialize + + DeserializeOwned + + Send + + Sync + + 'static, +{ + // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -178,7 +189,6 @@ impl Publisher where ClientCtx: ExtractContext Publisher where ClientCtx: ExtractContext publisher::Publisher for Publisher where ClientCtx: ExtractContext + From + Send + Sync + 'static { +impl publisher::Publisher for Publisher +where + ClientCtx: ExtractContext + From + Send + Sync + 'static, +{ type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 359b4af8b..8c8d6619e 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -37,8 +37,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = - transport::channel::unbounded(); + let (client_transport, server_transport) = transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index c203bf0b8..d67340449 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,7 +10,7 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::{SharedContext}; +use tarpc::context::SharedContext; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 525a16a47..77b19ba46 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -12,6 +12,7 @@ use crate::{ }; use futures::{future, prelude::*}; use opentelemetry::trace::TracerProvider as _; +use std::marker::PhantomData; use std::{ io, sync::{ @@ -19,7 +20,6 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use std::marker::PhantomData; use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, @@ -67,18 +67,22 @@ impl AddService for AddServer { #[derive(Clone)] struct DoubleServer { add_client: add::AddClient, - ghost: PhantomData + ghost: PhantomData, } impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, - ClientCtx: From + Send + Sync + 'static + ClientCtx: From + Send + Sync + 'static, { type Context = SharedContext; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(&mut ClientCtx::from(context::SharedContext::current()), x, x) + .add( + &mut ClientCtx::from(context::SharedContext::current()), + x, + x, + ) .await .map_err(|e| e.to_string()) } @@ -141,7 +145,7 @@ fn make_stub( where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, - ClientCtx: ExtractContext + From + Send + Sync + 'static + ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -196,7 +200,11 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, SharedContext> { add_client, ghost: PhantomData }.serve(); + let server = DoubleServer::<_, SharedContext> { + add_client, + ghost: PhantomData, + } + .serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 40ba7e461..27856c729 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -19,6 +19,7 @@ use crate::{ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; +use std::marker::PhantomData; use std::{ any::Any, convert::TryFrom, @@ -30,7 +31,6 @@ use std::{ }, time::SystemTime, }; -use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -105,7 +105,7 @@ pub struct Channel { next_request_id: Arc, ///TODO: Document - ghost: PhantomData + ghost: PhantomData, } impl Clone for Channel { @@ -114,7 +114,7 @@ impl Clone for Channel { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), - ghost: PhantomData + ghost: PhantomData, } } } @@ -122,7 +122,7 @@ impl Clone for Channel { impl Channel where Req: RequestName, - ClientCtx: ExtractContext + ClientCtx: ExtractContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -135,11 +135,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call( - &self, - ctx: &mut ClientCtx, - request: Req, - ) -> Result { + pub async fn call(&self, ctx: &mut ClientCtx, request: Req) -> Result { let span = Span::current(); let mut shared_context = ctx.extract(); shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { @@ -148,7 +144,10 @@ where ); shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id())); + span.record( + "rpc.trace_id", + tracing::field::display(shared_context.trace_id()), + ); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -266,7 +265,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }, dispatch: RequestDispatch { config, @@ -275,7 +274,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }, } } @@ -309,11 +308,9 @@ pub struct RequestDispatch { impl RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From + ClientCtx: ExtractContext + From, { - fn in_flight_requests<'a>( - self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests { + fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -595,7 +592,10 @@ where fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server).map(|m| (response.context.extract(), m)), + response + .message + .map_err(RpcError::Server) + .map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -673,7 +673,7 @@ where impl Future for RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From + ClientCtx: ExtractContext + From, { type Output = Result<(), ChannelError>; @@ -704,7 +704,8 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext + pub ctx: context::SharedContext, + ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, @@ -716,7 +717,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::{SharedContext}; + use crate::context::SharedContext; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, @@ -790,7 +791,8 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((SharedContext::current(), "well done"))).unwrap(); + tx.send(Ok((SharedContext::current(), "well done"))) + .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -902,7 +904,8 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -1003,13 +1006,18 @@ mod tests { fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, + Pin< + Box< + RequestDispatch>, + >, + >, Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let transport: AlwaysErrorTransport = + AlwaysErrorTransport(cause, PhantomData); let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, @@ -1017,19 +1025,19 @@ mod tests { in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData<( I, ClientCtx)>); + struct AlwaysErrorTransport(TransportError, PhantomData<(I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1096,10 +1104,7 @@ mod tests { String, String, ClientCtx, - UnboundedChannel< - Response, - ClientMessage, - >, + UnboundedChannel, ClientMessage>, >, >, >, @@ -1119,14 +1124,14 @@ mod tests { in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), - ghost: PhantomData + ghost: PhantomData, }; (Box::pin(dispatch), channel, server_channel) @@ -1165,7 +1170,7 @@ mod tests { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { ctx: SharedContext::current(), span: Span::current(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 5b648098b..0ea5ba5ac 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,4 +1,9 @@ -use crate::{trace, util::{Compact, TimeUntil}}; +use crate::client::RpcError; +use crate::context::SharedContext; +use crate::{ + trace, + util::{Compact, TimeUntil}, +}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -8,8 +13,6 @@ use std::{ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; -use crate::client::RpcError; -use crate::context::{SharedContext}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -78,7 +81,11 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Result<(SharedContext, Res), RpcError>) -> Option { + pub fn complete_request( + &mut self, + request_id: u64, + result: Result<(SharedContext, Res), RpcError>, + ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 992f6d611..51cececae 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,11 +1,11 @@ //! Provides a Stub trait, implemented by types that can call remote services. +use crate::context::{ExtractContext, SharedContext}; use crate::{ RequestName, client::{Channel, RpcError}, server::Serve, }; -use crate::context::{ExtractContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -37,17 +37,13 @@ pub trait Stub { impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext + ClientCtx: ExtractContext, { type Req = Req; type Resp = Resp; type ClientCtx = ClientCtx; - async fn call( - &self, - ctx: &mut Self::ClientCtx, - request: Req, - ) -> Result { + async fn call(&self, ctx: &mut Self::ClientCtx, request: Req) -> Result { Self::call(self, ctx, request).await } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 60efafc91..9664a2aa7 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,9 +5,7 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::{ - client::{RpcError, stub}, - }; + use crate::client::{RpcError, stub}; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -98,9 +96,7 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::{ - client::{RpcError, stub} - }; + use crate::client::{RpcError, stub}; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 577ef5362..171f8918e 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -2,13 +2,13 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, }; -use std::{collections::HashMap, hash::Hash, io}; use std::marker::PhantomData; +use std::{collections::HashMap, hash::Hash, io}; /// A mock stub that returns user-specified responses. pub struct Mock { responses: HashMap, - ghost: PhantomData + ghost: PhantomData, } impl Mock @@ -19,7 +19,7 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), - ghost: PhantomData + ghost: PhantomData, } } } @@ -33,11 +33,7 @@ where type Resp = Resp; type ClientCtx = ServerCtx; - async fn call( - &self, - _: &mut Self::ClientCtx, - request: Self::Req, - ) -> Result { + async fn call(&self, _: &mut Self::ClientCtx, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index e89a7f044..bc357e50f 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -34,7 +34,7 @@ pub struct SharedContext { /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. - pub trace_context: trace::Context + pub trace_context: trace::Context, } ///TODO @@ -45,7 +45,10 @@ pub trait ExtractContext { fn update(&mut self, value: Ctx); } -impl ExtractContext for T where T: Clone { +impl ExtractContext for T +where + T: Clone, +{ fn extract(&self) -> T { self.clone() } diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 7d345a203..1ed69fcd8 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,10 +6,11 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{SpanExt}, + context::SpanExt, trace, util::TimeUntil, }; @@ -27,7 +28,6 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; -use crate::context::{ExtractContext, SharedContext}; mod in_flight_requests; pub mod request_hook; @@ -59,7 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel( + self, + transport: T, + ) -> BaseChannel where T: Transport, ClientMessage>, ServerCtx: ExtractContext, @@ -113,7 +116,10 @@ impl Copy for ServeFn where F: /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -168,7 +174,7 @@ pub struct BaseChannel { impl BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -443,7 +449,7 @@ where impl Stream for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { type Item = Result, ChannelError>; @@ -548,11 +554,12 @@ where } } -impl Sink> for BaseChannel +impl Sink> + for BaseChannel where T: Transport, ClientMessage>, T::Error: Error, - ServerCtx: ExtractContext + ServerCtx: ExtractContext, { type Error = ChannelError; @@ -610,7 +617,6 @@ where T: Transport, ClientMessage>, ServerCtx: ExtractContext, { - type Req = Req; type Resp = Resp; type Transport = T; @@ -995,6 +1001,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; + use crate::context::{ExtractContext, SharedContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1012,7 +1019,6 @@ mod tests { task::Poll, time::{Duration, Instant}, }; - use crate::context::{ExtractContext, SharedContext}; fn test_channel() -> ( Pin< @@ -1024,7 +1030,7 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, >, >, @@ -1045,9 +1051,8 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, - >, >, >, @@ -1073,7 +1078,7 @@ mod tests { ClientMessage, Response, >, - SharedContext + SharedContext, >, >, >, @@ -1114,12 +1119,11 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline where ServerCtx: ExtractContext { - async fn before( - &mut self, - ctx: &mut ServerCtx, - _: &Req, - ) -> Result<(), ServerError> { + impl BeforeRequest for SetDeadline + where + ServerCtx: ExtractContext, + { + async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { let mut inner = ctx.extract(); inner.deadline = self.0; ctx.update(inner); @@ -1159,21 +1163,13 @@ mod tests { } } impl BeforeRequest for PrintLatency { - async fn before( - &mut self, - _: &mut ServerCtx, - _: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } impl AfterRequest for PrintLatency { - async fn after( - &mut self, - _: &mut ServerCtx, - _: &mut Result, - ) { + async fn after(&mut self, _: &mut ServerCtx, _: &mut Result) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index deb723bda..34b372510 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -180,6 +180,7 @@ where mod tests { use super::*; + use crate::context::SharedContext; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -190,7 +191,6 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; - use crate::context::{SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -270,9 +270,10 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() - -> PendingSink>, Response> - { + pub fn default() -> PendingSink< + io::Result>, + Response, + > { PendingSink { ghost: PhantomData } } } @@ -298,7 +299,10 @@ mod tests { } } impl Channel - for PendingSink>, Response> + for PendingSink< + io::Result>, + Response, + > { type Req = Req; type Resp = Resp; diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index ce6319e25..1fa3cee51 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,11 +15,7 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after( - &mut self, - ctx: &mut ServerCtx, - resp: &mut Result, - ); + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result); } impl AfterRequest for F @@ -27,11 +23,7 @@ where F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { - async fn after( - &mut self, - ctx: &mut ServerCtx, - resp: &mut Result, - ) { + async fn after(&mut self, ctx: &mut ServerCtx, resp: &mut Result) { self(ctx, resp).await } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 3e2e091c8..1552a0b49 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,9 +6,9 @@ //! Provides a hook that runs before request execution. -use std::marker::PhantomData; use crate::{ServerError, server::Serve}; use futures::prelude::*; +use std::marker::PhantomData; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] @@ -20,11 +20,7 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -64,11 +60,7 @@ where F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } @@ -77,7 +69,7 @@ where pub struct HookThenServe { serve: Serv, hook: Hook, - ghost: PhantomData + ghost: PhantomData, } impl Clone for HookThenServe { @@ -88,7 +80,11 @@ impl Clone for HookThenServe HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook, ghost: PhantomData } + Self { + serve, + hook, + ghost: PhantomData, + } } } @@ -101,11 +97,7 @@ where type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve( - self, - ctx: &mut ServerCtx, - req: Self::Req, - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { let HookThenServe { serve, mut hook, .. } = self; @@ -154,14 +146,10 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest, ServerCtx> BeforeRequest - for BeforeRequestCons +impl, Rest: BeforeRequest, ServerCtx> + BeforeRequest for BeforeRequestCons { - async fn before( - &mut self, - ctx: &mut ServerCtx, - req: &Req, - ) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -175,8 +163,8 @@ impl BeforeRequest for BeforeRequestNil { } } -impl, Rest: BeforeRequestList, ServerCtx> BeforeRequestList - for BeforeRequestCons +impl, Rest: BeforeRequestList, ServerCtx> + BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 080c53b21..f3653a513 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -27,7 +27,9 @@ impl HookThenServeThenHook Clone for HookThenServeThenHook { +impl Clone + for HookThenServeThenHook +{ fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,7 +39,8 @@ impl Clone for HookThenServeThen } } -impl Serve for HookThenServeThenHook +impl Serve + for HookThenServeThenHook where Req: RequestName, Serv: Serve, @@ -47,11 +50,7 @@ where type Resp = Resp; type ServerCtx = ServerCtx; - async fn serve( - self, - ctx: &mut ServerCtx, - req: Req, - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 9a941f711..ce409dd85 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::{SharedContext}; +use crate::context::SharedContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -94,7 +94,9 @@ where } } -impl FakeChannel>, Response> { +impl + FakeChannel>, Response> +{ pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); @@ -120,7 +122,8 @@ impl FakeChannel>, Resp impl FakeChannel<(), ()> { pub fn default() - -> FakeChannel>, Response> { + -> FakeChannel>, Response> + { let (request_cancellation, canceled_requests) = cancellations(); let mut x = anymap3::AnyMap::new(); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index de9a8afdc..a698136f0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -161,6 +161,7 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { + use crate::context::SharedContext; use crate::{ ServerError, client::{self, RpcError}, @@ -175,7 +176,6 @@ mod tests { use futures::{prelude::*, stream}; use std::io; use tracing::trace; - use crate::context::SharedContext; #[test] fn ensure_is_transport() { @@ -188,8 +188,7 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = - transport::channel::unbounded(); + let (client_channel, server_channel) = transport::channel::unbounded(); tokio::spawn( stream::once(future::ready(server_channel)) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 6bcd255c4..a39922666 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,6 +1,6 @@ use futures::prelude::*; -use tarpc::context::{SharedContext}; -use tarpc::{serde_transport}; +use tarpc::context::SharedContext; +use tarpc::serde_transport; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 7d1f96e18..fd54b3db6 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,6 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; +use tarpc::context::SharedContext; use tarpc::{ ClientMessage, client::{self}, @@ -13,7 +14,6 @@ use tarpc::{ transport::channel, }; use tokio::join; -use tarpc::context::SharedContext; #[tarpc_plugins::service] trait Service { From 34a87d65f3f9abaaf27a20bc56838dc06e602a88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:35:27 +0100 Subject: [PATCH 17/26] cleanup... --- example-service/src/client.rs | 12 +- example-service/src/server.rs | 4 +- plugins/src/lib.rs | 7 +- plugins/tests/service.rs | 10 +- tarpc/examples/compression.rs | 6 +- tarpc/examples/custom_transport.rs | 8 +- tarpc/examples/pubsub.rs | 24 ++-- tarpc/examples/readme.rs | 6 +- tarpc/examples/tls_over_tcp.rs | 8 +- tarpc/examples/tracing.rs | 16 +-- tarpc/src/client.rs | 67 +++++------ tarpc/src/client/in_flight_requests.rs | 16 +-- tarpc/src/client/stub.rs | 14 +-- tarpc/src/client/stub/load_balance.rs | 6 +- tarpc/src/context.rs | 10 +- tarpc/src/lib.rs | 12 +- tarpc/src/server.rs | 107 +++++++++--------- tarpc/src/server/incoming.rs | 3 +- .../src/server/limits/requests_per_channel.rs | 14 +-- tarpc/src/server/request_hook.rs | 10 +- tarpc/src/server/request_hook/before.rs | 4 +- tarpc/src/server/testing.rs | 17 ++- tarpc/src/transport/channel.rs | 7 +- .../compile_fail/must_use_request_dispatch.rs | 4 +- .../must_use_request_dispatch.stderr | 6 +- tarpc/tests/dataservice.rs | 6 +- tarpc/tests/service_functional.rs | 36 +++--- 27 files changed, 207 insertions(+), 233 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 627e67504..64b2e0a89 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -5,11 +5,9 @@ // https://opensource.org/licenses/MIT. use clap::Parser; -use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::context::SharedContext; -use tarpc::{client, tokio_serde::formats::Json}; +use tarpc::{client, context, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -31,15 +29,13 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); - let transport = transport.await?; - // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let client = WorldClient::new(client::Config::default(), transport).spawn(); + let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = SharedContext::current(); - let mut context2 = SharedContext::current(); + let mut context = context::Context::current(); + let mut context2 = context::Context::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 9c9160e17..302336c57 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -16,7 +16,7 @@ use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, context, server::{self, Channel, incoming::Incoming}, @@ -37,7 +37,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 1a5b7e6db..b8f1ff826 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -377,8 +377,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// ```no_run /// use tarpc::{client, transport, service, server::{self, Channel}}; /// use futures_util::{TryStreamExt, sink::SinkExt};/// -/// -/// use tarpc::context::SharedContext; +/// use tarpc::context; /// /// #[service] /// pub trait Calculator { @@ -404,7 +403,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// type Context = SharedContext; +/// type Context = context::Context; /// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } @@ -568,7 +567,7 @@ impl ServiceGenerator<'_> { quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { - type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>; + type Context: ::tarpc::context::ExtractContext<::tarpc::context::Context>; #( #rpc_fns )* diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index d8213f4d4..3bd8b4c4e 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::SharedContext; +use tarpc::context::Context; #[test] fn att_service_trait() { @@ -13,10 +13,10 @@ fn att_service_trait() { } impl Foo for () { - type Context = SharedContext; + type Context = context::Context; async fn two_part( self, - _: &mut context::SharedContext, + _: &mut context::Context, s: String, i: i32, ) -> (String, i32) { @@ -44,7 +44,7 @@ fn raw_idents() { } impl r#trait for () { - type Context = SharedContext; + type Context = context::Context; async fn r#await( self, _: &mut Self::Context, @@ -72,7 +72,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - type Context = SharedContext; + type Context = context::Context; async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 1a3a7d566..3c5fa6fcf 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -9,7 +9,7 @@ use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ client, context, serde_transport::tcp, @@ -109,7 +109,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } @@ -140,7 +140,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", client - .hello(&mut context::SharedContext::current(), "friend".into()) + .hello(&mut context::Context::current(), "friend".into()) .await? ); Ok(()) diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 859bed0ed..80f8a03be 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,8 +6,8 @@ use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::SharedContext; -use tarpc::serde_transport as transport; +use tarpc::context::Context; +use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -22,7 +22,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = SharedContext; + type Context = context::Context; async fn ping(self, _: &mut Self::Context) {} } #[tokio::main] @@ -53,7 +53,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut SharedContext::current()) + .ping(&mut context::Context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 5e915e1b0..3cf95b27e 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -51,7 +51,7 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use subscriber::Subscriber as _; -use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::context::{ExtractContext}; use tarpc::{ ClientMessage, client, context, serde_transport::tcp, @@ -84,7 +84,7 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - type Context = SharedContext; + type Context = context::Context; async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } @@ -164,8 +164,8 @@ async fn spawn(fut: impl Future + Send + 'static) { impl Publisher where - ClientCtx: ExtractContext - + From + ClientCtx: ExtractContext + + From + Serialize + DeserializeOwned + Send @@ -235,7 +235,7 @@ where ) { // Populate the topics if let Ok(topics) = subscriber - .topics(&mut ClientCtx::from(context::SharedContext::current())) + .topics(&mut ClientCtx::from(context::Context::current())) .await { self.clients.lock().unwrap().insert( @@ -291,7 +291,7 @@ where impl publisher::Publisher for Publisher where - ClientCtx: ExtractContext + From + Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static, { type Context = ClientCtx; async fn publish(self, _: &mut Self::Context, topic: String, message: String) { @@ -306,7 +306,7 @@ where publications.push(async { client .receive( - &mut ClientCtx::from(context::SharedContext::current()), + &mut ClientCtx::from(context::Context::current()), topic.clone(), message.clone(), ) @@ -356,7 +356,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher:: { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -383,7 +383,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "calculus".into(), "sqrt(2)".into(), ) @@ -391,7 +391,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "cool shorts".into(), "hello to all".into(), ) @@ -399,7 +399,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "history".into(), "napoleon".to_string(), ) @@ -409,7 +409,7 @@ async fn main() -> anyhow::Result<()> { publisher .publish( - &mut SharedContext::current(), + &mut context::Context::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index 8c8d6619e..ff0307d39 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, client, context, server::{self, Channel}, @@ -25,7 +25,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = SharedContext; + type Context = context::Context; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> { // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. let hello = client - .hello(&mut context::SharedContext::current(), "Stim".to_string()) + .hello(&mut context::Context::current(), "Stim".to_string()) .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index d67340449..07eb6a8ef 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,8 +10,8 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::SharedContext; -use tarpc::serde_transport as transport; +use tarpc::context::Context; +use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -32,7 +32,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = SharedContext; + type Context = context::Context; async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut SharedContext::current()) + .ping(&mut context::Context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 77b19ba46..abf1cbbc1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -20,7 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::context::{ExtractContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -58,7 +58,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - type Context = SharedContext; + type Context = context::Context; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -73,13 +73,13 @@ struct DoubleServer { impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, - ClientCtx: From + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static, { - type Context = SharedContext; + type Context = context::Context; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client .add( - &mut ClientCtx::from(context::SharedContext::current()), + &mut ClientCtx::from(context::Context::current()), x, x, ) @@ -145,7 +145,7 @@ fn make_stub( where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, - ClientCtx: ExtractContext + From + Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -200,7 +200,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, SharedContext> { + let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData, } @@ -215,7 +215,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::SharedContext::current(), 1) + .double(&mut context::Context::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 27856c729..fab3b2548 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,6 @@ mod in_flight_requests; pub mod stub; -use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -33,6 +32,7 @@ use std::{ }; use tokio::sync::{mpsc, oneshot}; use tracing::Span; +use crate::context::ExtractContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -122,7 +122,7 @@ impl Clone for Channel { impl Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -184,7 +184,7 @@ where /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -212,7 +212,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result<(SharedContext, Resp), RpcError> { + async fn response(mut self) -> Result<(context::Context, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -308,7 +308,7 @@ pub struct RequestDispatch { impl RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + From, { fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests @@ -673,7 +673,7 @@ where impl Future for RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + From, { type Output = Result<(), ChannelError>; @@ -704,12 +704,12 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::SharedContext, + pub ctx: context::Context, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -717,12 +717,7 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::context::SharedContext; - use crate::{ - ChannelError, ClientMessage, Response, - client::{Config, in_flight_requests::InFlightRequests}, - transport::{self, channel::UnboundedChannel}, - }; + use crate::{ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, transport::{self, channel::UnboundedChannel}, context}; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ @@ -748,7 +743,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = SharedContext::current(); + let context = context::Context::current(); dispatch .in_flight_requests @@ -763,7 +758,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok("Resp".into()), }) .await @@ -791,7 +786,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((SharedContext::current(), "well done"))) + tx.send(Ok((context::Context::current(), "well done"))) .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { @@ -810,7 +805,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -840,7 +835,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok("hello".into()), }, ) @@ -851,7 +846,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -867,7 +862,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -888,7 +883,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -905,7 +900,7 @@ mod tests { async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); let (mut dispatch, mut channel, mut cx) = - set_up_always_err::(TransportError::Flush); + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -922,11 +917,11 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up::(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send let resp = channel - .call(&mut SharedContext::current(), "hi".to_string()) + .call(&mut context::Context::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -934,7 +929,7 @@ mod tests { #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -957,7 +952,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -974,7 +969,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -984,7 +979,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -994,7 +989,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -1140,13 +1135,13 @@ mod tests { async fn send_request<'a, ClientCtx>( channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: SharedContext::current(), + ctx: context::Context::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1164,15 +1159,15 @@ mod tests { async fn reserve_for_send<'a, ClientCtx>( channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { let permit = channel.to_dispatch.reserve().await.unwrap(); |request| { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: SharedContext::current(), + ctx: context::Context::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 0ea5ba5ac..cc5091fc6 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,9 +1,5 @@ use crate::client::RpcError; -use crate::context::SharedContext; -use crate::{ - trace, - util::{Compact, TimeUntil}, -}; +use crate::{context, trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; use std::time::Instant; use std::{ @@ -34,7 +30,7 @@ impl Default for InFlightRequests { struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -62,7 +58,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -84,7 +80,7 @@ impl InFlightRequests { pub fn complete_request( &mut self, request_id: u64, - result: Result<(SharedContext, Res), RpcError>, + result: Result<(context::Context, Res), RpcError>, ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); @@ -103,7 +99,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Result<(SharedContext, Res), RpcError> + 'a, + mut result: impl FnMut() -> Result<(context::Context, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -129,7 +125,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Result<(SharedContext, Res), RpcError>, + expired_error: impl Fn() -> Result<(context::Context, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 51cececae..1a0dbfff6 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,11 +1,7 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::context::{ExtractContext, SharedContext}; -use crate::{ - RequestName, - client::{Channel, RpcError}, - server::Serve, -}; +use crate::context::{ExtractContext}; +use crate::{RequestName, client::{Channel, RpcError}, server::Serve, context}; pub mod load_balance; pub mod retry; @@ -37,7 +33,7 @@ pub trait Stub { impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext, { type Req = Req; type Resp = Resp; @@ -50,11 +46,11 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; - type ClientCtx = SharedContext; + type ClientCtx = context::Context; async fn call( &self, ctx: &mut Self::ClientCtx, diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 9664a2aa7..4b9d9df3a 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -197,17 +197,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::SharedContext::current(), 'a') + .call(&mut context::Context::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::SharedContext::current(), 'b') + .call(&mut context::Context::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::SharedContext::current(), 'c') + .call(&mut context::Context::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index bc357e50f..6db79e49b 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -23,7 +23,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// be different for each request in scope. #[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct SharedContext { +pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -111,7 +111,7 @@ mod absolute_to_relative_time { } } -assert_impl_all!(SharedContext: Send, Sync); +assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) @@ -126,7 +126,7 @@ impl Default for Deadline { } } -impl SharedContext { +impl Context { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -152,11 +152,11 @@ impl SharedContext { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &SharedContext); + fn set_context(&self, context: &Context); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &SharedContext) { + fn set_context(&self, context: &Context) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index fc79e3056..cb5a64085 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,7 +124,7 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! type Context = context::SharedContext; +//! type Context = context::Context; //! // Each defined rpc generates an async fn that serves the RPC //! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") @@ -145,7 +145,6 @@ //! # use tarpc::{ //! # ClientMessage, //! # client, context, -//! # context::{SharedContext}, //! # transport::channel, //! # server::{self, Channel}, //! # }; @@ -161,7 +160,7 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # type Context = SharedContext; +//! # type Context = context::Context; //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") @@ -184,12 +183,12 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::SharedContext::current(); +//! let mut context = context::Context::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); @@ -256,7 +255,6 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use crate::context::SharedContext; use std::ops::Deref; use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; @@ -543,7 +541,7 @@ impl ServerError { impl Request where - Ctx: Deref, + Ctx: Deref, { /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 1ed69fcd8..d1a384e32 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,11 +6,10 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. -use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::SpanExt, + context, context::SpanExt, trace, util::TimeUntil, }; @@ -28,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::ExtractContext; mod in_flight_requests; pub mod request_hook; @@ -65,7 +65,7 @@ impl Config { ) -> BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -174,7 +174,7 @@ pub struct BaseChannel { impl BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -369,7 +369,6 @@ where /// use tarpc::{ /// ClientMessage, /// context, - /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -389,7 +388,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -413,7 +412,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{SharedContext}}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -430,7 +429,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -449,7 +448,7 @@ where impl Stream for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Item = Result, ChannelError>; @@ -559,7 +558,7 @@ impl Sink> where T: Transport, ClientMessage>, T::Error: Error, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Error = ChannelError; @@ -615,7 +614,7 @@ impl AsRef for BaseChannel impl Channel for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { type Req = Req; type Resp = Resp; @@ -773,7 +772,6 @@ where /// /// ```rust /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; - /// use tarpc::context::{SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -789,7 +787,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -878,7 +876,6 @@ impl InFlightRequest { /// use tarpc::{ /// ClientMessage, /// context, - /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -898,7 +895,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -1001,7 +998,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::{ExtractContext, SharedContext}; + use crate::context::{ExtractContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1027,14 +1024,14 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -1048,15 +1045,15 @@ mod tests { Req, Resp, UnboundedChannel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1075,15 +1072,15 @@ mod tests { Req, Resp, channel::Channel< - ClientMessage, - Response, + ClientMessage, + Response, >, - SharedContext, + context::Context, >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1093,9 +1090,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::SharedContext::current(), + context: context::Context::current(), id: 0, message: req, }) @@ -1111,7 +1108,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::SharedContext::current(), 7).await, + serve.serve(&mut context::Context::current(), 7).await, Ok(7) ); } @@ -1121,7 +1118,7 @@ mod tests { struct SetDeadline(Instant); impl BeforeRequest for SetDeadline where - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { let mut inner = ctx.extract(); @@ -1134,7 +1131,7 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::SharedContext, i| { + let serve = serve(move |ctx: &mut context::Context, i| { async move { assert_eq!(ctx.deadline, some_time); Ok(i) @@ -1142,7 +1139,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::SharedContext::current(); + let mut ctx = context::Context::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1174,10 +1171,10 @@ mod tests { } } - let serve = serve(move |_: &mut context::SharedContext, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::SharedContext::current(), 7) + .serve(&mut context::Context::current(), 7) .await?; Ok(()) } @@ -1185,11 +1182,11 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::SharedContext, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::SharedContext::current(), 7) + .serve(&mut context::Context::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) @@ -1203,14 +1200,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: () }), Err(AlreadyExistsError) @@ -1226,7 +1223,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1234,7 +1231,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1257,7 +1254,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1286,7 +1283,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1328,7 +1325,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1351,7 +1348,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1360,7 +1357,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1419,7 +1416,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1428,7 +1425,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1440,7 +1437,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .await @@ -1451,7 +1448,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1472,7 +1469,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1481,7 +1478,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .unwrap(); @@ -1492,7 +1489,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: (), }) .unwrap(); @@ -1502,7 +1499,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 6a71124b1..67d46e330 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -50,7 +50,6 @@ where /// use tarpc::{ /// ClientMessage, /// context, -/// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -67,7 +66,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::SharedContext::current(); +/// let mut context = context::Context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 34b372510..32b126aa6 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -180,7 +180,6 @@ where mod tests { use super::*; - use crate::context::SharedContext; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -191,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -271,8 +271,8 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() -> PendingSink< - io::Result>, - Response, + io::Result>, + Response, > { PendingSink { ghost: PhantomData } } @@ -300,14 +300,14 @@ mod tests { } impl Channel for PendingSink< - io::Result>, - Response, + io::Result>, + Response, > { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = SharedContext; + type ServerCtx = context::Context; fn config(&self) -> &Config { unimplemented!() } @@ -337,7 +337,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: SharedContext::current(), + context: context::Context::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 4f3d60377..090c4a72c 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::SharedContext, req: &i32| { + /// .before(|_ctx: &mut context::Context, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -95,13 +95,13 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::SharedContext, resp: &mut Result| { + /// .after(|_ctx: &mut context::Context, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::SharedContext::current(); + /// let mut context = context::Context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 1552a0b49..df4873e83 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -129,7 +129,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::SharedContext::current(); +/// let mut context = context::Context::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -219,7 +219,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = crate::context::SharedContext::current(); + let mut context = crate::context::Context::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index ce409dd85..047464cf5 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -use crate::context::SharedContext; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -39,8 +38,8 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> + for FakeChannel> { type Error = io::Error; @@ -50,7 +49,7 @@ impl Sink> fn start_send( mut self: Pin<&mut Self>, - response: Response, + response: Response, ) -> Result<(), Self::Error> { self.as_mut() .project() @@ -72,14 +71,14 @@ impl Sink> } impl Channel - for FakeChannel>, Response> + for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = SharedContext; + type ServerCtx = context::Context; fn config(&self) -> &Config { &self.config @@ -95,14 +94,14 @@ where } impl - FakeChannel>, Response> + FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::SharedContext { + context: context::Context { deadline: Instant::now(), trace_context: Default::default(), }, @@ -122,7 +121,7 @@ impl impl FakeChannel<(), ()> { pub fn default() - -> FakeChannel>, Response> + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index a698136f0..1ff75e70d 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -161,7 +161,6 @@ impl Sink for Channel { #[cfg(all(test, feature = "tokio1"))] mod tests { - use crate::context::SharedContext; use crate::{ ServerError, client::{self, RpcError}, @@ -193,7 +192,7 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx: &mut SharedContext, request: String| { + .execute(serve(|_ctx: &mut context::Context, request: String| { async move { request.parse::().map_err(|_| { ServerError::new( @@ -212,10 +211,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::SharedContext::current(), "123".into()) + .call(&mut context::Context::current(), "123".into()) .await; let response2 = client - .call(&mut context::SharedContext::current(), "abc".into()) + .call(&mut context::Context::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index a5238fe8b..812fc4ee7 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; -use tarpc::context::SharedContext; +use tarpc::context::Context; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e0ec77ff3..4fe34df5f 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index a39922666..a2e458361 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,5 @@ use futures::prelude::*; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::serde_transport; use tarpc::{ client, context, @@ -23,7 +23,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - type Context = SharedContext; + type Context = context::Context; async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, @@ -55,7 +55,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::SharedContext::current(), TestData::White) + .get_opposite_color(&mut context::Context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index fd54b3db6..abe1ba0df 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,7 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::SharedContext; +use tarpc::context::Context; use tarpc::{ ClientMessage, client::{self}, @@ -25,7 +25,7 @@ trait Service { struct Server; impl Service for Server { - type Context = SharedContext; + type Context = context::Context; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -50,7 +50,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::SharedContext::current(), 1) + .call(&mut context::Context::current(), 1) .await .unwrap(), 2 @@ -68,7 +68,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - type Context = SharedContext; + type Context = context::Context; async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); @@ -85,7 +85,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::SharedContext::current(); + let mut ctx = context::Context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -128,12 +128,12 @@ async fn serde_tcp() -> anyhow::Result<()> { assert_matches!( client - .add(&mut context::SharedContext::current(), 1, 2) + .add(&mut context::Context::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::SharedContext::current(), "Tim".to_string()).await, + client.hey(&mut context::Context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -166,10 +166,10 @@ async fn serde_uds() -> anyhow::Result<()> { // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::SharedContext::current(), 1, 2) + .add(&mut context::Context::current(), 1, 2) .await; let res2 = client - .hey(&mut context::SharedContext::current(), "Tim".to_string()) + .hey(&mut context::Context::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -194,7 +194,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::SharedContext::current(); + let mut context = context::Context::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -224,9 +224,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::SharedContext::current(); - let mut context2 = context::SharedContext::current(); - let mut context3 = context::SharedContext::current(); + let mut context1 = context::Context::current(); + let mut context2 = context::Context::current(); + let mut context3 = context::Context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -258,8 +258,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::SharedContext::current(); - let mut context2 = context::SharedContext::current(); + let mut context1 = context::Context::current(); + let mut context2 = context::Context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -281,7 +281,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type Context = SharedContext; + type Context = context::Context; async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 @@ -301,11 +301,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::SharedContext::current()).await, + client.count(&mut context::Context::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::SharedContext::current()).await, + client.count(&mut context::Context::current()).await, Ok(2) ); From 116c718178158699aa9b6a15d8cbcf7845eb03ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 16:41:53 +0100 Subject: [PATCH 18/26] cleanup --- example-service/src/client.rs | 1 + example-service/src/lib.rs | 2 ++ example-service/src/server.rs | 9 ++++----- plugins/Cargo.toml | 1 - plugins/src/lib.rs | 10 ++-------- tarpc/examples/compression.rs | 2 +- tarpc/examples/custom_transport.rs | 3 +-- tarpc/examples/pubsub.rs | 6 +++--- tarpc/examples/readme.rs | 4 ++-- tarpc/examples/tls_over_tcp.rs | 2 +- tarpc/examples/tracing.rs | 2 +- 11 files changed, 18 insertions(+), 24 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 64b2e0a89..e1d496f59 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use service::{WorldClient, init_tracing}; diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs index 26b49e2ac..fd5031e75 100644 --- a/example-service/src/lib.rs +++ b/example-service/src/lib.rs @@ -4,6 +4,8 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + use opentelemetry::trace::TracerProvider as _; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 302336c57..a8e3324fc 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use clap::Parser; use futures::{future, prelude::*}; @@ -11,14 +12,12 @@ use rand::{ thread_rng, }; use service::{World, init_tracing}; -use std::ops::Deref; use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; -use tarpc::context::Context; use tarpc::{ - ClientMessage, context, + context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, }; @@ -67,11 +66,11 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); + let server = HelloServer(channel.transport().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index eeab84924..8be746c26 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,6 +30,5 @@ proc-macro = true [dev-dependencies] assert-type-eq = "0.1.0" futures = "0.3" -futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc", features = ["serde1"] } diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index b8f1ff826..8e35ee49d 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -4,13 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] #![recursion_limit = "512"] -extern crate proc_macro; -extern crate proc_macro2; -extern crate quote; -extern crate syn; - use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; @@ -375,9 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}}; -/// use futures_util::{TryStreamExt, sink::SinkExt};/// -/// use tarpc::context; +/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; /// /// #[service] /// pub trait Calculator { diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index 3c5fa6fcf..aa12147e5 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -3,13 +3,13 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder}; use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*}; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use std::{io, io::Read, io::Write}; -use tarpc::context::Context; use tarpc::{ client, context, serde_transport::tcp, diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 80f8a03be..1548d8e80 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -3,10 +3,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] -use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::Context; use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 3cf95b27e..f33e38833 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" /// port. Because both publishers and subscribers initiate their connections to the PubSub @@ -41,8 +42,6 @@ use futures::{ use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use std::ops::Shl; use std::{ collections::HashMap, error::Error, @@ -50,10 +49,11 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use serde::Serialize; use subscriber::Subscriber as _; use tarpc::context::{ExtractContext}; use tarpc::{ - ClientMessage, client, context, + client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index ff0307d39..f2a98cc87 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -3,11 +3,11 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::context::Context; use tarpc::{ - ClientMessage, client, context, + client, context, server::{self, Channel}, transport, }; diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 07eb6a8ef..b970a6dac 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -3,6 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] use futures::prelude::*; use rustls_pemfile::certs; @@ -10,7 +11,6 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::context::Context; use tarpc::{context, serde_transport as transport}; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index abf1cbbc1..b67d98fe4 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -3,7 +3,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. - +#![deny(warnings, unused, dead_code)] #![allow(clippy::type_complexity)] use crate::{ From 044629d651d7b931120890206867521efa54d996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 17:06:39 +0100 Subject: [PATCH 19/26] cleanup --- example-service/src/client.rs | 4 +- plugins/tests/service.rs | 8 +-- tarpc/Cargo.toml | 2 - tarpc/examples/compression.rs | 11 ++-- tarpc/examples/custom_transport.rs | 6 ++- tarpc/examples/pubsub.rs | 50 ++++-------------- tarpc/examples/readme.rs | 12 ++--- tarpc/examples/tls_over_tcp.rs | 12 +++-- tarpc/examples/tracing.rs | 19 ++----- tarpc/src/client.rs | 14 ++--- tarpc/src/client/in_flight_requests.rs | 15 +++--- tarpc/src/client/stub.rs | 25 ++++----- tarpc/src/client/stub/load_balance.rs | 6 +-- tarpc/src/context.rs | 5 ++ tarpc/src/lib.rs | 2 +- tarpc/src/server.rs | 52 +++++++++---------- tarpc/src/server/incoming.rs | 2 +- .../src/server/limits/requests_per_channel.rs | 2 +- tarpc/src/server/request_hook.rs | 6 +-- tarpc/src/server/request_hook/before.rs | 4 +- tarpc/src/server/testing.rs | 4 -- tarpc/src/transport/channel.rs | 4 +- tarpc/tests/dataservice.rs | 3 +- tarpc/tests/service_functional.rs | 30 +++++------ 24 files changed, 118 insertions(+), 180 deletions(-) diff --git a/example-service/src/client.rs b/example-service/src/client.rs index e1d496f59..6f3930343 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -35,8 +35,8 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { - let mut context = context::Context::current(); - let mut context2 = context::Context::current(); + let mut context = context::current(); + let mut context2 = context::current(); // Send the request twice, just to be safe! ;) tokio::select! { diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 3bd8b4c4e..2e450095c 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,7 +1,6 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; -use tarpc::context::Context; #[test] fn att_service_trait() { @@ -14,12 +13,7 @@ fn att_service_trait() { impl Foo for () { type Context = context::Context; - async fn two_part( - self, - _: &mut context::Context, - s: String, - i: i32, - ) -> (String, i32) { + async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) { (s, i) } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 0a5efc137..778eb0938 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,8 +61,6 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" -anymap3 = "1.0.1" -serde-value = "0.7" [dev-dependencies] assert_matches = "1.4" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index aa12147e5..c96014eea 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -122,26 +122,21 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; - let addr = incoming.local_addr(); tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); - let transport = add_compression(transport); - BaseChannel::with_defaults(transport) + BaseChannel::with_defaults(add_compression(transport)) .execute(HelloServer.serve()) .for_each(spawn) .await; }); let transport = tcp::connect(addr, Bincode::default).await?; - let transport = add_compression(transport); - let client = WorldClient::new(client::Config::default(), transport).spawn(); + let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); println!( "{}", - client - .hello(&mut context::Context::current(), "friend".into()) - .await? + client.hello(&mut context::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 1548d8e80..7fe32bfa7 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,7 +6,8 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{context, serde_transport as transport}; +use tarpc::{context}; +use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; @@ -24,6 +25,7 @@ impl PingService for Service { type Context = context::Context; async fn ping(self, _: &mut Self::Context) {} } + #[tokio::main] async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; @@ -52,7 +54,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut context::Context::current()) + .ping(&mut context::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index f33e38833..70b41fdb3 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -162,6 +162,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } +// TODO: Remove serde bounds here impl Publisher where ClientCtx: ExtractContext @@ -172,7 +173,6 @@ where + Sync + 'static, { - // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -234,9 +234,7 @@ where subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber - .topics(&mut ClientCtx::from(context::Context::current())) - .await + if let Ok(topics) = subscriber.topics(&mut ClientCtx::from(context::current())).await { self.clients.lock().unwrap().insert( subscriber_addr, @@ -301,16 +299,10 @@ where Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); - for client in subscribers.values_mut() { publications.push(async { - client - .receive( - &mut ClientCtx::from(context::Context::current()), - topic.clone(), - message.clone(), - ) - .await + let mut context = ClientCtx::from(context::current()); + client.receive(&mut context, topic.clone(), message.clone(), ).await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -366,14 +358,12 @@ async fn main() -> anyhow::Result<()> { let _subscriber0 = Subscriber::connect( addrs.subscriptions, vec!["calculus".into(), "cool shorts".into()], - ) - .await?; + ).await?; let _subscriber1 = Subscriber::connect( addrs.subscriptions, vec!["cool shorts".into(), "history".into()], - ) - .await?; + ).await?; let publisher = publisher::PublisherClient::new( client::Config::default(), @@ -382,38 +372,18 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish( - &mut context::Context::current(), - "calculus".into(), - "sqrt(2)".into(), - ) - .await?; + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()).await?; publisher - .publish( - &mut context::Context::current(), - "cool shorts".into(), - "hello to all".into(), - ) - .await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()).await?; publisher - .publish( - &mut context::Context::current(), - "history".into(), - "napoleon".to_string(), - ) - .await?; + .publish(&mut context::current(), "history".into(), "napoleon".to_string()).await?; drop(_subscriber0); publisher - .publish( - &mut context::Context::current(), - "cool shorts".into(), - "hello to who?".into(), - ) - .await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ).await?; tracer_provider.shutdown()?; info!("done."); diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index f2a98cc87..f8f298921 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -6,11 +6,7 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{ - client, context, - server::{self, Channel}, - transport, -}; +use tarpc::{client, context, server::{self, Channel}}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -37,7 +33,7 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = transport::channel::unbounded(); + let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); @@ -49,9 +45,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client - .hello(&mut context::Context::current(), "Stim".to_string()) - .await?; + let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index b970a6dac..0ba8f2581 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -11,10 +11,6 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; -use tarpc::{context, serde_transport as transport}; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -23,6 +19,12 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; +use tarpc::context; +use tarpc::serde_transport as transport; +use tarpc::server::{BaseChannel, Channel}; +use tarpc::tokio_serde::formats::Bincode; +use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; + #[tarpc::service] pub trait PingService { async fn ping() -> String; @@ -146,7 +148,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(&mut context::Context::current()) + .ping(&mut context::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index b67d98fe4..f36db524e 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -78,11 +78,7 @@ where type Context = context::Context; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add( - &mut ClientCtx::from(context::Context::current()), - x, - x, - ) + .add(&mut ClientCtx::from(context::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -134,10 +130,7 @@ where } fn make_stub( - backends: [impl Transport>, Response> - + Send - + Sync - + 'static; N], + backends: [impl Transport>, Response> + Send + Sync + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, load_balance::RoundRobin, Resp, ClientCtx>>, @@ -200,11 +193,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, context::Context> { - add_client, - ghost: PhantomData, - } - .serve(); + let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; @@ -215,7 +204,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!( "{:?}", double_client - .double(&mut context::Context::current(), 1) + .double(&mut context::current(), 1) .await? ); } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index fab3b2548..90f7cac45 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -743,7 +743,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let context = context::Context::current(); + let context = context::current(); dispatch .in_flight_requests @@ -758,7 +758,7 @@ mod tests { server_channel .send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok("Resp".into()), }) .await @@ -786,7 +786,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((context::Context::current(), "well done"))) + tx.send(Ok((context::current(), "well done"))) .unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { @@ -835,7 +835,7 @@ mod tests { &mut server_channel, Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok("hello".into()), }, ) @@ -921,7 +921,7 @@ mod tests { drop(dispatch); // error on send let resp = channel - .call(&mut context::Context::current(), "hi".to_string()) + .call(&mut context::current(), "hi".to_string()) .await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1141,7 +1141,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::Context::current(), + ctx: context::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1167,7 +1167,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::Context::current(), + ctx: context::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index cc5091fc6..d6424c564 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,14 +1,17 @@ -use crate::client::RpcError; -use crate::{context, trace, util::{Compact, TimeUntil}}; +use crate::{ + context, trace, + util::{Compact, TimeUntil} +}; use fnv::FnvHashMap; -use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, }; +use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -77,11 +80,7 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request( - &mut self, - request_id: u64, - result: Result<(context::Context, Res), RpcError>, - ) -> Option { + pub fn complete_request(&mut self, request_id: u64, result: Result<(context::Context, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 1a0dbfff6..2aa6908e3 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,7 +1,12 @@ //! Provides a Stub trait, implemented by types that can call remote services. -use crate::context::{ExtractContext}; -use crate::{RequestName, client::{Channel, RpcError}, server::Serve, context}; +use crate::{ + RequestName, + client::{Channel, RpcError}, + context, + context::ExtractContext, + server::Serve, +}; pub mod load_balance; pub mod retry; @@ -23,11 +28,7 @@ pub trait Stub { type ClientCtx; /// Calls a remote service. - async fn call( - &self, - ctx: &mut Self::ClientCtx, - request: Self::Req, - ) -> Result; + async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) -> Result; } impl Stub for Channel @@ -46,26 +47,22 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; - type ClientCtx = context::Context; + type ClientCtx = S::ServerCtx; async fn call( &self, ctx: &mut Self::ClientCtx, req: Self::Req, ) -> Result { - let mut server_ctx = ctx.clone(); - let res = self .clone() - .serve(&mut server_ctx, req) + .serve(ctx, req) .await .map_err(RpcError::Server); - *ctx = server_ctx; - res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 4b9d9df3a..43c1c8b23 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -197,17 +197,17 @@ mod consistent_hash { for _ in 0..2 { let resp = stub - .call(&mut context::Context::current(), 'a') + .call(&mut context::current(), 'a') .await?; assert_eq!(resp, 1); let resp = stub - .call(&mut context::Context::current(), 'b') + .call(&mut context::current(), 'b') .await?; assert_eq!(resp, 2); let resp = stub - .call(&mut context::Context::current(), 'c') + .call(&mut context::current(), 'c') .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 6db79e49b..423084c61 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -148,6 +148,11 @@ impl Context { } } +///TODO: Document +pub fn current() -> Context { + Context::current() +} + /// An extension trait for [`tracing::Span`] for propagating tarpc Contexts. pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index cb5a64085..e34722b6d 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -188,7 +188,7 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let mut context = context::Context::current(); +//! let mut context = context::current(); //! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d1a384e32..a7560132b 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -388,7 +388,7 @@ where /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -429,7 +429,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!( /// client.call(&mut context, 1).await.unwrap(), /// 2); @@ -787,7 +787,7 @@ where /// .for_each(|response| async move { /// tokio::spawn(response); /// }.boxed())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -895,7 +895,7 @@ impl InFlightRequest { /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` @@ -1092,7 +1092,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::Context::current(), + context: context::current(), id: 0, message: req, }) @@ -1108,7 +1108,7 @@ mod tests { async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); assert_matches!( - serve.serve(&mut context::Context::current(), 7).await, + serve.serve(&mut context::current(), 7).await, Ok(7) ); } @@ -1139,7 +1139,7 @@ mod tests { .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::Context::current(); + let mut ctx = context::current(); ctx.deadline = some_other_time; deadline_hook.serve(&mut ctx, 7).await?; Ok(()) @@ -1174,7 +1174,7 @@ mod tests { let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(&mut context::Context::current(), 7) + .serve(&mut context::current(), 7) .await?; Ok(()) } @@ -1186,7 +1186,7 @@ mod tests { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook - .serve(&mut context::Context::current(), 7) + .serve(&mut context::current(), 7) .await; assert_matches!(resp, Err(_)); Ok(()) @@ -1200,14 +1200,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: () }), Err(AlreadyExistsError) @@ -1223,7 +1223,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1231,7 +1231,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1254,7 +1254,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1283,7 +1283,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1325,7 +1325,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1348,7 +1348,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1357,7 +1357,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1416,7 +1416,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1425,7 +1425,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1437,7 +1437,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .await @@ -1448,7 +1448,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1469,7 +1469,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1478,7 +1478,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .unwrap(); @@ -1489,7 +1489,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::Context::current(), + context: context::current(), message: (), }) .unwrap(); @@ -1499,7 +1499,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, - context: context::Context::current(), + context: context::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 67d46e330..36e942f62 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -66,7 +66,7 @@ where /// BaseChannel::new(server::Config::default(), rx) /// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// let mut context = context::Context::current(); +/// let mut context = context::current(); /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 32b126aa6..4c7c8dbcc 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -337,7 +337,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, - context: context::Context::current(), + context: context::current(), message: Ok(1), }) .unwrap(); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 090c4a72c..cce5998ee 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -58,7 +58,7 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -101,7 +101,7 @@ pub trait RequestHook: Serve { /// } /// future::ready(()) /// }); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` @@ -153,7 +153,7 @@ pub trait RequestHook: Serve { /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) /// }.boxed()).before_and_after(PrintLatency(Instant::now())); - /// let mut context = context::Context::current(); + /// let mut context = context::current(); /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index df4873e83..adfac8e79 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -129,7 +129,7 @@ where /// Ok(()) /// }) /// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); -/// let mut context = context::Context::current(); +/// let mut context = context::current(); /// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); @@ -219,7 +219,7 @@ fn before_request_list() { Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); - let mut context = crate::context::Context::current(); + let mut context = crate::context::current(); let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 047464cf5..76df940ce 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -125,10 +125,6 @@ impl FakeChannel<(), ()> { { let (request_cancellation, canceled_requests) = cancellations(); - let mut x = anymap3::AnyMap::new(); - - x.entry::<&str>(); - FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 1ff75e70d..4a4e216c0 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -211,10 +211,10 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); let response1 = client - .call(&mut context::Context::current(), "123".into()) + .call(&mut context::current(), "123".into()) .await; let response2 = client - .call(&mut context::Context::current(), "abc".into()) + .call(&mut context::current(), "abc".into()) .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index a2e458361..5a5b2f8e7 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,4 @@ use futures::prelude::*; -use tarpc::context::Context; use tarpc::serde_transport; use tarpc::{ client, context, @@ -55,7 +54,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(&mut context::Context::current(), TestData::White) + .get_opposite_color(&mut context::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index abe1ba0df..e716437c7 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -4,9 +4,7 @@ use futures::{ prelude::*, }; use std::time::{Duration, Instant}; -use tarpc::context::Context; use tarpc::{ - ClientMessage, client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, @@ -50,7 +48,7 @@ async fn sequential() { ); assert_eq!( client - .call(&mut context::Context::current(), 1) + .call(&mut context::current(), 1) .await .unwrap(), 2 @@ -85,7 +83,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::Context::current(); + let mut ctx = context::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); let _ = client.r#loop(&mut ctx).await; }); @@ -128,12 +126,12 @@ async fn serde_tcp() -> anyhow::Result<()> { assert_matches!( client - .add(&mut context::Context::current(), 1, 2) + .add(&mut context::current(), 1, 2) .await, Ok(3) ); assert_matches!( - client.hey(&mut context::Context::current(), "Tim".to_string()).await, + client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -166,10 +164,10 @@ async fn serde_uds() -> anyhow::Result<()> { // Save results using socket so we can clean the socket even if our test assertions fail let res1 = client - .add(&mut context::Context::current(), 1, 2) + .add(&mut context::current(), 1, 2) .await; let res2 = client - .hey(&mut context::Context::current(), "Tim".to_string()) + .hey(&mut context::current(), "Tim".to_string()) .await; assert_matches!(res1, Ok(3)); @@ -194,7 +192,7 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context = context::Context::current(); + let mut context = context::current(); let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); @@ -224,9 +222,9 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::Context::current(); - let mut context2 = context::Context::current(); - let mut context3 = context::Context::current(); + let mut context1 = context::current(); + let mut context2 = context::current(); + let mut context3 = context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -258,8 +256,8 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let mut context1 = context::Context::current(); - let mut context2 = context::Context::current(); + let mut context1 = context::current(); + let mut context2 = context::current(); let req1 = client.add(&mut context1, 1, 2); let req2 = client.add(&mut context2, 3, 4); @@ -301,11 +299,11 @@ async fn counter() -> anyhow::Result<()> { let client = CounterClient::new(client::Config::default(), tx).spawn(); assert_matches!( - client.count(&mut context::Context::current()).await, + client.count(&mut context::current()).await, Ok(1) ); assert_matches!( - client.count(&mut context::Context::current()).await, + client.count(&mut context::current()).await, Ok(2) ); From b1120173085399d4db1a9f4ea37116a02a754d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 17:11:43 +0100 Subject: [PATCH 20/26] more --- tarpc/examples/pubsub.rs | 18 ++++++++++++------ tarpc/examples/tracing.rs | 7 +------ tarpc/src/client/stub.rs | 17 ++++------------- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 70b41fdb3..6c0099a97 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -358,12 +358,14 @@ async fn main() -> anyhow::Result<()> { let _subscriber0 = Subscriber::connect( addrs.subscriptions, vec!["calculus".into(), "cool shorts".into()], - ).await?; + ) + .await?; let _subscriber1 = Subscriber::connect( addrs.subscriptions, vec!["cool shorts".into(), "history".into()], - ).await?; + ) + .await?; let publisher = publisher::PublisherClient::new( client::Config::default(), @@ -372,18 +374,22 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()).await?; + .publish(&mut context::current(), "calculus".into(), "sqrt(2)".into()) + .await?; publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()).await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()) + .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()).await?; + .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .await?; drop(_subscriber0); publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ).await?; + .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ) + .await?; tracer_provider.shutdown()?; info!("done."); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index f36db524e..0789d0a43 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -201,12 +201,7 @@ async fn main() -> anyhow::Result<()> { double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); for _ in 1..=5 { - tracing::info!( - "{:?}", - double_client - .double(&mut context::current(), 1) - .await? - ); + tracing::info!("{:?}", double_client.double(&mut context::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 2aa6908e3..5e473566c 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -28,7 +28,8 @@ pub trait Stub { type ClientCtx; /// Calls a remote service. - async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) -> Result; + async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) + -> Result; } impl Stub for Channel @@ -52,17 +53,7 @@ where type Req = S::Req; type Resp = S::Resp; type ClientCtx = S::ServerCtx; - async fn call( - &self, - ctx: &mut Self::ClientCtx, - req: Self::Req, - ) -> Result { - let res = self - .clone() - .serve(ctx, req) - .await - .map_err(RpcError::Server); - - res + async fn call(&self, ctx: &mut Self::ClientCtx, req: Self::Req) -> Result { + self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } From b988a2d39edff521df47d86a1923b96dbc3d6252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 18:42:19 +0100 Subject: [PATCH 21/26] cleanup --- plugins/src/lib.rs | 2 +- tarpc/src/client/stub/load_balance.rs | 20 +++--- tarpc/src/context.rs | 11 +-- tarpc/src/lib.rs | 68 ++----------------- tarpc/src/server/incoming.rs | 3 +- .../src/server/limits/requests_per_channel.rs | 12 +--- tarpc/src/server/request_hook/before.rs | 12 ++-- .../server/request_hook/before_and_after.rs | 3 +- tarpc/src/server/testing.rs | 15 ++-- tarpc/src/transport/channel.rs | 17 ++--- tarpc/tests/service_functional.rs | 48 ++++--------- 11 files changed, 54 insertions(+), 157 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 8e35ee49d..61a2e32a0 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -371,7 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::Context}; /// /// #[service] /// pub trait Calculator { diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index 43c1c8b23..eb605ecf9 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,7 +5,9 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::client::{RpcError, stub}; + use crate::{ + client::{RpcError, stub}, + }; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -96,7 +98,9 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::client::{RpcError, stub}; + use crate::{ + client::{RpcError, stub} + }; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, @@ -196,19 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub - .call(&mut context::current(), 'a') - .await?; + let resp = stub.call(&mut context::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub - .call(&mut context::current(), 'b') - .await?; + let resp = stub.call(&mut context::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub - .call(&mut context::current(), 'c') - .await?; + let resp = stub.call(&mut context::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 423084c61..a1b50c72e 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -111,12 +111,18 @@ mod absolute_to_relative_time { } } + assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } +/// Returns the context for the current request, or a default Context if no request is active. +pub fn current() -> Context { + Context::current() +} + #[derive(Clone)] struct Deadline(Instant); @@ -148,11 +154,6 @@ impl Context { } } -///TODO: Document -pub fn current() -> Context { - Context::current() -} - /// An extension trait for [`tracing::Span`] for propagating tarpc Contexts. pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index e34722b6d..06385b15c 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -143,9 +143,7 @@ //! # prelude::*, //! # }; //! # use tarpc::{ -//! # ClientMessage, //! # client, context, -//! # transport::channel, //! # server::{self, Channel}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. @@ -161,6 +159,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! # type Context = context::Context; +//! # //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") @@ -171,8 +170,7 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! use futures::future::Shared; -//! let (client_transport, server_transport) = channel::unbounded(); +//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -255,8 +253,7 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; -use std::ops::Deref; -use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; +use std::{any::Any, error::Error, io, sync::Arc}; /// A message from a client to a server. #[derive(Debug)] @@ -284,35 +281,8 @@ pub enum ClientMessage { }, } -impl ClientMessage { - /// Creates a new ClientMessage by mapping the context using the provided function. - pub fn map_context(self, f: F) -> ClientMessage - where - F: FnOnce(Ctx) -> Ctx2, - { - match self { - ClientMessage::Request(Request { - context, - id, - message, - }) => ClientMessage::Request(Request { - context: f(context), - id, - message, - }), - ClientMessage::Cancel { - trace_context, - request_id, - } => ClientMessage::Cancel { - trace_context, - request_id, - }, - } - } -} - /// A request from a client to a server. -#[derive(Debug)] +#[derive(Clone, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. @@ -333,9 +303,7 @@ impl RequestName for Arc where Req: RequestName, { - fn name(&self) -> &str { - self.as_ref().name() - } + fn name(&self) -> &str { self.as_ref().name() } } impl RequestName for Box @@ -401,21 +369,6 @@ pub struct Response { /// The response body, or an error if the request failed. pub message: Result, } - -impl Response { - /// Creates a modified Response by mapping the context using the provided function. - pub fn map_context(self, f: F) -> Response - where - F: FnOnce(Ctx) -> Ctx2, - { - Response { - request_id: self.request_id, - context: f(self.context), - message: self.message, - } - } -} - /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] @@ -538,17 +491,6 @@ impl ServerError { Self { kind, detail } } } - -impl Request -where - Ctx: Deref, -{ - /// Returns the deadline for this request. - pub fn deadline(&self) -> &Instant { - &self.context.deadline - } -} - #[test] fn test_channel_any_casts() { use assert_matches::assert_matches; diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 36e942f62..568ae4495 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -58,7 +58,8 @@ where /// /// #[tokio::main] /// async fn main() { -/// let (tx, rx) = transport::channel::unbounded(); +/// use tracing_subscriber::filter::FilterExt; +/// let (tx, rx) = transport::channel::unbounded(); /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); /// tokio::spawn(dispatch); /// diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 4c7c8dbcc..527cb6f98 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -270,10 +270,8 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() -> PendingSink< - io::Result>, - Response, - > { + pub fn default() + -> PendingSink>, Response, > { PendingSink { ghost: PhantomData } } } @@ -299,11 +297,7 @@ mod tests { } } impl Channel - for PendingSink< - io::Result>, - Response, - > - { + for PendingSink>, Response> { type Req = Req; type Resp = Resp; type Transport = (); diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index adfac8e79..13fc18509 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -80,11 +80,7 @@ impl Clone for HookThenServe HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { - serve, - hook, - ghost: PhantomData, - } + Self { serve, hook, ghost: PhantomData } } } @@ -97,7 +93,11 @@ where type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { + async fn serve( + self, + ctx: &mut ServerCtx, + req: Self::Req + ) -> Result { let HookThenServe { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index f3653a513..934d82ad5 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -39,8 +39,7 @@ impl Clone } } -impl Serve - for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, Serv: Serve, diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 76df940ce..a92b50fc2 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -38,8 +38,7 @@ where } } -impl Sink> - for FakeChannel> +impl Sink> for FakeChannel> { type Error = io::Error; @@ -47,10 +46,7 @@ impl Sink> self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send( - mut self: Pin<&mut Self>, - response: Response, - ) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -70,8 +66,7 @@ impl Sink> } } -impl Channel - for FakeChannel>, Response> +impl Channel for FakeChannel>, Response> where Req: Unpin, { @@ -93,8 +88,7 @@ where } } -impl - FakeChannel>, Response> +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); @@ -124,7 +118,6 @@ impl FakeChannel<(), ()> { -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); - FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 4a4e216c0..47f3e4928 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -188,21 +188,18 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let (client_channel, server_channel) = transport::channel::unbounded(); - tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx: &mut context::Context, request: String| { - async move { + .execute(serve(|_ctx: &mut context::Context, request: String| async move { request.parse::().map_err(|_| { ServerError::new( io::ErrorKind::InvalidInput, format!("{request:?} is not an int"), ) }) - } - .boxed() - })) + }.boxed() + )) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -210,12 +207,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client - .call(&mut context::current(), "123".into()) - .await; - let response2 = client - .call(&mut context::current(), "abc".into()) - .await; + let response1 = client.call(&mut context::current(), "123".into()).await; + let response2 = client.call(&mut context::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index e716437c7..9d3c70e37 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -35,24 +35,16 @@ impl Service for Server { #[tokio::test] async fn sequential() { - let (tx, rx) = channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| { - async move { Ok(i + 1) }.boxed() - })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!( - client - .call(&mut context::current(), 1) - .await - .unwrap(), - 2 - ); + assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -76,7 +68,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. @@ -124,12 +116,7 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!( - client - .add(&mut context::current(), 1, 2) - .await, - Ok(3) - ); + assert_matches!(client.add(&mut context::current(), 1, 2).await, Ok(3)); assert_matches!( client.hey(&mut context::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." @@ -159,16 +146,11 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; - let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client - .add(&mut context::current(), 1, 2) - .await; - let res2 = client - .hey(&mut context::current(), "Tim".to_string()) - .await; + let res1 = client.add(&mut context::current(), 1, 2).await; + let res2 = client.hey(&mut context::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -180,7 +162,7 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -210,7 +192,7 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( stream::once(ready(rx)) @@ -247,7 +229,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = transport::channel::unbounded(); + let (tx, rx) = tarpc::transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -298,14 +280,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!( - client.count(&mut context::current()).await, - Ok(1) - ); - assert_matches!( - client.count(&mut context::current()).await, - Ok(2) - ); + assert_matches!(client.count(&mut context::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::current()).await, Ok(2)); Ok(()) } From 21e9223e494749eb2270e50f349abea6416bc4f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 18:47:22 +0100 Subject: [PATCH 22/26] cleanup --- plugins/src/lib.rs | 27 +++++++++++++++------------ tarpc/src/context.rs | 1 - tarpc/src/lib.rs | 4 +++- tarpc/src/server/testing.rs | 3 +-- tarpc/src/transport/channel.rs | 15 +++++++-------- tarpc/tests/service_functional.rs | 3 +-- 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 61a2e32a0..e7c325d42 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -543,18 +543,21 @@ impl ServiceGenerator<'_> { .. } = self; - let rpc_fns = rpcs.iter().zip(return_types.iter()).map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; - } - }, + let rpc_fns = rpcs + .iter() + .zip(return_types.iter()) + .map( + |( + RpcMethod { + attrs, ident, args, .. + }, + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index a1b50c72e..d4a6611e0 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -111,7 +111,6 @@ mod absolute_to_relative_time { } } - assert_impl_all!(Context: Send, Sync); fn ten_seconds_from_now() -> Instant { diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 06385b15c..76d9a1815 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -303,7 +303,9 @@ impl RequestName for Arc where Req: RequestName, { - fn name(&self) -> &str { self.as_ref().name() } + fn name(&self) -> &str { + self.as_ref().name() + } } impl RequestName for Box diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index a92b50fc2..39eabdaf5 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -114,8 +114,7 @@ impl FakeChannel>, R } impl FakeChannel<(), ()> { - pub fn default() - -> FakeChannel>, Response> + pub fn default() -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 47f3e4928..35c81fb1e 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -192,14 +192,13 @@ mod tests { stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) .execute(serve(|_ctx: &mut context::Context, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - }.boxed() - )) + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 9d3c70e37..559521414 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -36,7 +36,6 @@ impl Service for Server { #[tokio::test] async fn sequential() { let (tx, rx) = tarpc::transport::channel::unbounded(); - let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( @@ -270,7 +269,7 @@ async fn counter() -> anyhow::Result<()> { let (tx, rx) = channel::unbounded(); - tokio::task::spawn(async move { + tokio::task::spawn(async { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); From 835d92c41689b72e2c9869496b932c2e0077d098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Wed, 26 Nov 2025 21:14:39 +0100 Subject: [PATCH 23/26] cleanup --- tarpc/src/client.rs | 116 ++++++++++++++++++-------------------------- tarpc/src/server.rs | 115 +++++++++++-------------------------------- 2 files changed, 77 insertions(+), 154 deletions(-) diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 90f7cac45..74031d969 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -144,10 +144,7 @@ where ); shared_context.trace_context.new_child() }); - span.record( - "rpc.trace_id", - tracing::field::display(shared_context.trace_id()), - ); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id()), ); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -310,7 +307,9 @@ where C: Transport, Response>, ClientCtx: ExtractContext + From, { - fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests { + fn in_flight_requests<'a>( + self: &'a mut Pin<&mut Self>, + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -327,10 +326,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send( - self: &mut Pin<&mut Self>, - message: ClientMessage, - ) -> Result<(), C::Error> { + fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -539,17 +535,11 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ctx.into(), + context: ctx.into(), //TODO: <-- This will actually initialize an empty client context, and the transport will never see the original }); self.in_flight_requests() - .insert_request( - request_id, - trace_context, - deadline, - span.clone(), - response_completion, - ) + .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -571,11 +561,10 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (trace_context, span, request_id) = - match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { @@ -704,8 +693,8 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { + ///TODO: this should be a &mut ClientCtx pub ctx: context::Context, - ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, @@ -717,7 +706,12 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; - use crate::{ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, transport::{self, channel::UnboundedChannel}, context}; + use crate::{ + ChannelError, ClientMessage, Response, + client::{Config, in_flight_requests::InFlightRequests}, + context, + transport::{self, channel::UnboundedChannel} + }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; use std::{ @@ -747,13 +741,7 @@ mod tests { dispatch .in_flight_requests - .insert_request( - 0, - context.trace_context, - context.deadline, - Span::current(), - tx, - ) + .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -899,8 +887,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = - set_up_always_err::(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -920,9 +907,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel - .call(&mut context::current(), "hi".to_string()) - .await; + let resp = channel.call(&mut context::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1001,18 +986,13 @@ mod tests { fn set_up_always_err( cause: TransportError, ) -> ( - Pin< - Box< - RequestDispatch>, - >, - >, + Pin>>>, Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = - AlwaysErrorTransport(cause, PhantomData); + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, @@ -1132,31 +1112,6 @@ mod tests { (Box::pin(dispatch), channel, server_channel) } - async fn send_request<'a, ClientCtx>( - channel: &'a mut Channel, - request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> ResponseGuard<'a, String> { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - let request = DispatchRequest { - ctx: context::current(), - span: Span::current(), - request_id, - request: request.to_string(), - response_completion, - }; - let response_guard = ResponseGuard { - response, - cancellation: &channel.cancellation, - request_id, - cancel: true, - }; - channel.to_dispatch.send(request).await.unwrap(); - response_guard - } - async fn reserve_for_send<'a, ClientCtx>( channel: &'a mut Channel, response_completion: oneshot::Sender>, @@ -1183,6 +1138,31 @@ mod tests { } } + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, + request: &str, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> ResponseGuard<'a, String> { + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request = DispatchRequest { + ctx: context::current(), + span: Span::current(), + request_id, + request: request.to_string(), + response_completion, + }; + let response_guard = ResponseGuard { + response, + cancellation: &channel.cancellation, + request_id, + cancel: true, + }; + channel.to_dispatch.send(request).await.unwrap(); + response_guard + } + async fn send_response( channel: &mut UnboundedChannel< ClientMessage, diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index a7560132b..7e08db475 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context, context::SpanExt, + context::{self, SpanExt}, trace, util::TimeUntil, }; @@ -59,10 +59,7 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel( - self, - transport: T, - ) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where T: Transport, ClientMessage>, ServerCtx: ExtractContext, @@ -84,11 +81,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve( - self, - ctx: &mut Self::ServerCtx, - req: Self::Req, - ) -> Result; + async fn serve(self, ctx: &mut Self::ServerCtx, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -116,10 +109,8 @@ impl Copy for ServeFn where F: /// Result>`. pub fn serve(f: F) -> ServeFn where - for<'a> F: FnOnce( - &'a mut ServerCtx, - Req, - ) -> Pin> + 'a + Send>>, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -130,6 +121,7 @@ where impl Serve for ServeFn where Req: RequestName, + // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. for<'a> F: FnOnce( &'a mut ServerCtx, Req, @@ -314,10 +306,7 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport< - Response::Resp>, - TrackedRequest::Req>, - >, + Self: Transport::Resp>, TrackedRequest::Req>>, { /// Type of request item. type Req; @@ -553,8 +542,7 @@ where } } -impl Sink> - for BaseChannel +impl Sink> for BaseChannel where T: Transport, ClientMessage>, T::Error: Error, @@ -569,10 +557,7 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send( - mut self: Pin<&mut Self>, - response: Response, - ) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -1018,19 +1003,7 @@ mod tests { }; fn test_channel() -> ( - Pin< - Box< - BaseChannel< - Req, - Resp, - UnboundedChannel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, + Pin, Response>, context::Context>>>, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1038,21 +1011,7 @@ mod tests { } fn test_requests() -> ( - Pin< - Box< - Requests< - BaseChannel< - Req, - Resp, - UnboundedChannel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, - >, + Pin, Response>, context::Context>>>>, UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); @@ -1065,21 +1024,7 @@ mod tests { fn test_bounded_requests( capacity: usize, ) -> ( - Pin< - Box< - Requests< - BaseChannel< - Req, - Resp, - channel::Channel< - ClientMessage, - Response, - >, - context::Context, - >, - >, - >, - >, + Pin, Response>, context::Context>>>>, channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); @@ -1107,10 +1052,7 @@ mod tests { #[tokio::test] async fn test_serve() { let serve = serve(|_, i| async move { Ok(i) }.boxed()); - assert_matches!( - serve.serve(&mut context::current(), 7).await, - Ok(7) - ); + assert_matches!(serve.serve(&mut context::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1120,7 +1062,11 @@ mod tests { where ServerCtx: ExtractContext, { - async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut ServerCtx, + _: &Req + ) -> Result<(), ServerError> { let mut inner = ctx.extract(); inner.deadline = self.0; ctx.update(inner); @@ -1131,13 +1077,10 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| { - async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - } - .boxed() - }); + let serve = serve(move |ctx: &mut context::Context, i| async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; @@ -1160,7 +1103,11 @@ mod tests { } } impl BeforeRequest for PrintLatency { - async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + _: &mut ServerCtx, + _: &Req + ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } @@ -1185,9 +1132,7 @@ mod tests { let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook - .serve(&mut context::current(), 7) - .await; + let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1393,9 +1338,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request - .execute(serve(|_, _| async { Ok(()) }.boxed())) - .await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() From 9d9f6461c47891ed7ae858f42eec768cd80bca6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Fri, 26 Dec 2025 12:55:57 +0100 Subject: [PATCH 24/26] allow shared context sent between the client and server to be customizable as well. Since this is part of the contract between client and server, it is implemented as an option to the service macro. --- example-service/src/server.rs | 2 +- plugins/src/lib.rs | 92 ++++-- plugins/tests/service.rs | 6 +- tarpc/examples/compression.rs | 6 +- tarpc/examples/custom_context.rs | 262 ++++++++++++++++++ tarpc/examples/custom_transport.rs | 4 +- tarpc/examples/pubsub.rs | 68 ++--- tarpc/examples/readme.rs | 11 +- tarpc/examples/tls_over_tcp.rs | 2 +- tarpc/examples/tracing.rs | 36 +-- tarpc/src/client.rs | 247 ++++++++++------- tarpc/src/client/in_flight_requests.rs | 32 ++- tarpc/src/client/stub.rs | 18 +- tarpc/src/client/stub/load_balance.rs | 8 +- tarpc/src/context.rs | 118 +++++++- tarpc/src/lib.rs | 6 +- tarpc/src/server.rs | 177 ++++++++---- .../src/server/limits/requests_per_channel.rs | 16 +- tarpc/src/server/request_hook.rs | 4 +- tarpc/src/server/request_hook/before.rs | 12 +- .../server/request_hook/before_and_after.rs | 3 +- tarpc/src/server/testing.rs | 30 +- tarpc/src/trace.rs | 9 +- tarpc/src/transport/channel.rs | 21 +- .../compile_fail/must_use_request_dispatch.rs | 4 +- .../must_use_request_dispatch.stderr | 6 +- .../compile_fail/serde1/opt_out_serde.stderr | 8 +- tarpc/tests/dataservice.rs | 2 +- tarpc/tests/service_functional.rs | 11 +- 29 files changed, 889 insertions(+), 332 deletions(-) create mode 100644 tarpc/examples/custom_context.rs diff --git a/example-service/src/server.rs b/example-service/src/server.rs index a8e3324fc..37bdb0f42 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -36,7 +36,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index e7c325d42..46375fce0 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -11,8 +11,8 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{ToTokens, format_ident, quote}; use syn::{ - AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path, - ReturnType, Token, Type, Visibility, braced, + AttrStyle, Attribute, Expr, ExprLit, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, + Path, ReturnType, Token, Type, Visibility, braced, ext::IdentExt, parenthesized, parse::{Parse, ParseStream}, @@ -139,6 +139,7 @@ impl Parse for RpcMethod { #[derive(Default)] struct DeriveMeta { derive: Option, + shared_context: Option, warnings: Vec, } @@ -251,6 +252,37 @@ impl Parse for DeriveMeta { ), } derive_serde.push(meta); + } else if segment.ident == "shared_context" { + let Expr::Lit(ExprLit { + lit: Lit::Str(ref v), + .. + }) = meta.value + else { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "tarpc::service requires a literal string value for the shared_context attribute" + ) + ); + continue; + }; + + let Ok(ty) = syn::parse_str(&v.value()) else { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "tarpc::service could not parse the value of the shared_context attribute as a type" + ) + ); + continue; + }; + + result = result.map(|d| DeriveMeta { + shared_context: Some(ty), + ..d + }) } else { extend_errors!( result, @@ -371,7 +403,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::DefaultContext}; /// /// #[service] /// pub trait Calculator { @@ -393,11 +425,17 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // A client can be made like so: /// let client = CalculatorClient::new(client::Config::default(), client_side); /// +/// // You would usually call it like so. +/// #[cfg(feature = "tokio1")] +/// let client = client.spawn(); +/// #[cfg(not(feature = "tokio1"))] +/// let client = client.client; // Don't forget to run the dispatch future! +/// /// // And a server like so: /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// type Context = context::Context; +/// type Context = context::DefaultContext; /// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } @@ -406,10 +444,13 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// // You would usually spawn on an async runtime. /// let server = server::BaseChannel::with_defaults(server_side); /// let _ = server.execute(CalculatorServer.serve()); +/// +/// let _ = client.add(&mut context::current(), 1,2); /// ``` #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { let derive_meta = parse_macro_input!(attr as DeriveMeta); + let unit_type: &Type = &parse_quote!(()); let Service { ref attrs, @@ -424,6 +465,9 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); + let shared_context = derive_meta + .shared_context + .unwrap_or(parse_quote!(::tarpc::context::DefaultContext)); let derives = match derive_meta.derive.as_ref() { Some(Derive::Explicit(paths)) => { if !paths.is_empty() { @@ -498,6 +542,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), derives: derives.as_ref(), + shared_context: &shared_context, warnings: &derive_meta.warnings, } .into_token_stream() @@ -525,6 +570,7 @@ struct ServiceGenerator<'a> { return_types: &'a [&'a Type], arg_pats: &'a [Vec<&'a Pat>], derives: Option<&'a TokenStream2>, + shared_context: &'a Type, warnings: &'a [TokenStream2], } @@ -540,31 +586,29 @@ impl ServiceGenerator<'_> { request_ident, response_ident, server_ident, + shared_context, .. } = self; - let rpc_fns = rpcs - .iter() - .zip(return_types.iter()) - .map( - |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { - quote! { - #( #attrs )* - async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; - } - }, + let rpc_fns = rpcs.iter().zip(return_types.iter()).map( + |( + RpcMethod { + attrs, ident, args, .. + }, + output, + )| { + quote! { + #( #attrs )* + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; + } + }, ); let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { - type Context: ::tarpc::context::ExtractContext<::tarpc::context::Context>; + type Context: ::tarpc::context::ExtractContext<#shared_context>; // = ::tarpc::context::DefaultContext; TODO: Add associated type default once https://github.com/rust-lang/rust/issues/29661 is stabilized #( #rpc_fns )* @@ -705,6 +749,7 @@ impl ServiceGenerator<'_> { client_ident, request_ident, response_ident, + shared_context, .. } = self; @@ -715,10 +760,10 @@ impl ServiceGenerator<'_> { /// [Futures](::core::future::Future). #vis struct #client_ident< ClientCtx, - Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx> + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx, #shared_context> >(Stub, ::std::marker::PhantomData); - impl ::std::clone::Clone for #client_ident { + impl ::std::clone::Clone for #client_ident { fn clone(&self) -> Self { Self(self.0.clone(), ::std::marker::PhantomData) } @@ -732,6 +777,7 @@ impl ServiceGenerator<'_> { vis, request_ident, response_ident, + shared_context, .. } = self; @@ -741,7 +787,7 @@ impl ServiceGenerator<'_> { #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, #shared_context, T> > where T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 2e450095c..7473cac3b 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,7 +12,7 @@ fn att_service_trait() { } impl Foo for () { - type Context = context::Context; + type Context = context::DefaultContext; async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) { (s, i) } @@ -38,7 +38,7 @@ fn raw_idents() { } impl r#trait for () { - type Context = context::Context; + type Context = context::DefaultContext; async fn r#await( self, _: &mut Self::Context, @@ -66,7 +66,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - type Context = context::Context; + type Context = context::DefaultContext; async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index c96014eea..3eebe963b 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -109,7 +109,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } @@ -136,7 +136,9 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(&mut context::current(), "friend".into()).await? + client + .hello(&mut context::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_context.rs b/tarpc/examples/custom_context.rs new file mode 100644 index 000000000..64978b063 --- /dev/null +++ b/tarpc/examples/custom_context.rs @@ -0,0 +1,262 @@ +// Copyright 2018 Google LLC +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file or at +// https://opensource.org/licenses/MIT. +#![deny(warnings, unused, dead_code)] + +use std::collections::HashMap; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use futures::{SinkExt, TryStreamExt, StreamExt, FutureExt}; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; +use tarpc::{client, server::{self, Channel}, trace, ClientMessage, Request, Response, ServerError, Transport}; +use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::server::request_hook::{AfterRequest, BeforeRequest, RequestHook}; +use tarpc::transport::channel::UnboundedChannel; + +#[derive(Serialize, Deserialize, Clone)] +struct CustomContext { + #[serde(with = "absolute_to_relative_time")] + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option +} + +impl SharedContext for CustomContext { + fn deadline(&self) -> Instant { + self.deadline + } + + fn trace_context(&self) -> trace::Context { + self.trace_context + } + + fn set_trace_context(&mut self, trace_context: trace::Context) { + self.trace_context = trace_context; + } +} + +#[derive(Clone, Debug)] +struct ClientContext { + pub session_id: Option, + pub delay_sending_by_seconds: u32, +} + +struct ServerContext { + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, + pub balance: u64, +} + +impl ExtractContext for ClientContext { + fn extract(&self) -> CustomContext { + CustomContext { + deadline: Instant::now().add(Duration::from_secs(60)), + trace_context: Default::default(), + session_id: self.session_id, + } + } + + fn update(&mut self, value: CustomContext) { + self.session_id = value.session_id; + } +} + +impl ExtractContext for ServerContext { + fn extract(&self) -> CustomContext { + CustomContext { + deadline: self.deadline, + trace_context: self.trace_context, + session_id: self.session_id, + } + } + + fn update(&mut self, value: CustomContext) { + self.deadline = value.deadline; + self.trace_context = value.trace_context; + self.session_id = value.session_id; + } +} + +/// This is the service definition. It looks a lot like a trait definition. +/// It defines one RPC, hello, which takes one arg, name, and returns a String. +#[tarpc::service(shared_context = "CustomContext")] +pub trait World { + async fn create_session() -> (); + async fn increase_balance(credits: u32) -> (); + async fn hello(name: String) -> Result; +} + +/// This is the type that implements the generated World trait. It is the business logic +/// and is used to start the server. +#[derive(Clone)] +struct HelloServer; + +impl World for HelloServer { + type Context = ServerContext; + + async fn create_session(self, ctx: &mut Self::Context) -> () { + ctx.session_id = Some(42); + ctx.balance = 0; + } + + async fn increase_balance(self, ctx: &mut Self::Context, credits: u32) -> () { + ctx.balance = ctx.balance + credits as u64; + } + + async fn hello(self, ctx: &mut Self::Context, name: String) -> Result { + if ctx.session_id != Some(42) { + Err("Session not yet initialized!")? + } + + if ctx.balance == 0 { + Err("Give me more money")? + } + + ctx.balance = ctx.balance - 1; + + Ok(format!("Hello, {name}!")) + } +} + +async fn spawn(fut: impl Future + Send + 'static) { + tokio::spawn(fut); +} + + + + +#[derive(Clone)] +struct SessionHook { + balances: Arc>> +} + +impl BeforeRequest for SessionHook { + async fn before(&mut self, ctx: &mut ServerContext, _req: &WorldRequest) -> Result<(), ServerError> { + if let Some(id) = ctx.session_id { + let balances = self.balances.lock().await; + ctx.balance = *balances.get(&id).unwrap_or(&0u64) + } + + Ok(()) + } +} + +impl AfterRequest for SessionHook { + async fn after(&mut self, ctx: &mut ServerContext, resp: &mut Result) { + if resp.is_ok() && let Some(id) = ctx.session_id { + let mut balances = self.balances.lock().await; + + let b = balances.entry(id).or_insert(0); + *b = ctx.balance; + } + } +} + + + + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let (client_transport, server_transport) = create_channel(); + + let client_transport = client_transport.with(|f: ClientMessage| async move { + + if let ClientMessage::Request(Request {ref context, ref message, .. }) = f { + println!("msg = {:?}, ctx = {:?}", message, context); + tokio::time::sleep(Duration::from_secs(context.delay_sending_by_seconds as u64)).await; + } + + Ok(f) + }.boxed()); + + let server = server::BaseChannel::with_defaults(server_transport); + let hook = SessionHook { balances: Arc::new(Mutex::new(HashMap::new())) }; + tokio::spawn(server.execute(HelloServer.serve().before_and_after(hook)).for_each(spawn)); + + // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` + // that takes a config and any Transport as input. + let client = WorldClient::new(client::Config::default(), client_transport).spawn(); + + // The client has an RPC method for each RPC defined in the annotated trait. It takes the same + // args as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + + let mut client_context: ClientContext = ClientContext { + session_id: None, + delay_sending_by_seconds: 1, + }; + + let hello = client.hello(&mut client_context, "Stan".to_string()).await?; + + assert_eq!(hello, Err("Session not yet initialized!".to_string())); + + let _ = client.create_session(&mut client_context).await?; + let hello = client.hello(&mut client_context, "Stan".to_string()).await?; + assert_eq!(hello, Err("Give me more money".to_string())); + + let _ = client.increase_balance(&mut client_context, 2u32).await?; + + let hello = client.hello(&mut client_context, "Stan".to_string()).await?; + assert_eq!(hello, Ok("Hello, Stan!".to_string())); + + let hello = client.hello(&mut client_context, "Frank".to_string()).await?; + assert_eq!(hello, Ok("Hello, Frank!".to_string())); + + let hello = client.hello(&mut client_context, "Joshua".to_string()).await?; + assert_eq!(hello, Err("Give me more money".to_string())); + + Ok(()) +} + +//*** Helper functions below ***// + +mod absolute_to_relative_time { + pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; + pub use std::time::{Duration, Instant}; + + pub fn serialize(deadline: &Instant, serializer: S) -> Result + where + S: Serializer, + { + let deadline = deadline.duration_since(Instant::now()); + deadline.serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let deadline = Duration::deserialize(deserializer)?; + Ok(Instant::now() + deadline) + } +} + +fn map_request_context(req: ClientMessage, f: impl FnOnce(Ctx) -> Ctx2) -> ClientMessage { + match req { + ClientMessage::Request(Request { context, id, message}) => ClientMessage::Request(Request { context: f(context), id, message }), + ClientMessage::Cancel { trace_context, request_id } => ClientMessage::Cancel { trace_context, request_id }, + _ => unimplemented!() + } +} + +fn map_response_context(res: Response, f: impl FnOnce(Ctx) -> Ctx2) -> Response { + Response { + request_id: res.request_id, + context: f(res.context), + message: res.message + } +} + +fn create_channel() -> (impl Transport, Response>, impl Transport, ClientMessage>) { + let (client, server): (UnboundedChannel, ClientMessage>, UnboundedChannel, Response>) = tarpc::transport::channel::unbounded(); + + let client = client.with(|m| futures::future::ok(map_request_context(m, |c: ClientContext| c.extract()))).map_ok(|r| map_response_context(r, |c: CustomContext| ClientContext { session_id: c.session_id, delay_sending_by_seconds: 0 })); + let server = server.with(|r| futures::future::ok(map_response_context(r, |c: ServerContext| c.extract()))).map_ok(|m| map_request_context(m, |c| ServerContext { deadline: c.deadline, trace_context: c.trace_context, session_id: c.session_id, balance: 0 })); + + (client, server) +} \ No newline at end of file diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 7fe32bfa7..aa62baf99 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -6,7 +6,7 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{context}; +use tarpc::context; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -22,7 +22,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = context::Context; + type Context = context::DefaultContext; async fn ping(self, _: &mut Self::Context) {} } diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index 6c0099a97..6fc08d7b5 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -41,7 +41,6 @@ use futures::{ }; use opentelemetry::trace::TracerProvider as _; use publisher::Publisher as _; -use serde::de::DeserializeOwned; use std::{ collections::HashMap, error::Error, @@ -49,9 +48,8 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; -use serde::Serialize; use subscriber::Subscriber as _; -use tarpc::context::{ExtractContext}; +use tarpc::context::DefaultContext; use tarpc::{ client, context, serde_transport::tcp, @@ -84,7 +82,7 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - type Context = context::Context; + type Context = context::DefaultContext; async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } @@ -138,13 +136,14 @@ struct Subscription { } #[derive(Debug)] -struct Publisher { +struct Publisher { clients: Arc>>, - subscriptions: - Arc>>>>, + subscriptions: Arc< + RwLock>>>, + >, } -impl Clone for Publisher { +impl Clone for Publisher { fn clone(&self) -> Self { Publisher { clients: self.clients.clone(), @@ -162,17 +161,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -// TODO: Remove serde bounds here -impl Publisher -where - ClientCtx: ExtractContext - + From - + Serialize - + DeserializeOwned - + Send - + Sync - + 'static, -{ +impl Publisher { async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -231,11 +220,10 @@ where async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(&mut ClientCtx::from(context::current())).await - { + if let Ok(topics) = subscriber.topics(&mut context::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -287,12 +275,10 @@ where } } -impl publisher::Publisher for Publisher -where - ClientCtx: ExtractContext + From + Send + Sync + 'static, -{ - type Context = ClientCtx; - async fn publish(self, _: &mut Self::Context, topic: String, message: String) { +impl publisher::Publisher for Publisher { + type Context = DefaultContext; + + async fn publish(self, ctx: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, @@ -301,8 +287,10 @@ where let mut publications = Vec::new(); for client in subscribers.values_mut() { publications.push(async { - let mut context = ClientCtx::from(context::current()); - client.receive(&mut context, topic.clone(), message.clone(), ).await + let mut context = ctx.clone(); + client + .receive(&mut context, topic.clone(), message.clone()) + .await }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until @@ -348,7 +336,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher:: { + let addrs = Publisher { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -378,17 +366,29 @@ async fn main() -> anyhow::Result<()> { .await?; publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to all".into()) + .publish( + &mut context::current(), + "cool shorts".into(), + "hello to all".into(), + ) .await?; publisher - .publish(&mut context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut context::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher - .publish(&mut context::current(), "cool shorts".into(), "hello to who?".into(), ) + .publish( + &mut context::current(), + "cool shorts".into(), + "hello to who?".into(), + ) .await?; tracer_provider.shutdown()?; diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index f8f298921..acbade9be 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -6,7 +6,10 @@ #![deny(warnings, unused, dead_code)] use futures::prelude::*; -use tarpc::{client, context, server::{self, Channel}}; +use tarpc::{ + client, context, + server::{self, Channel}, +}; /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. @@ -21,7 +24,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } @@ -45,7 +48,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(&mut context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 0ba8f2581..e7084fe84 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -34,7 +34,7 @@ pub trait PingService { struct Service; impl PingService for Service { - type Context = context::Context; + type Context = context::DefaultContext; async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 0789d0a43..e281f39fd 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -12,7 +12,6 @@ use crate::{ }; use futures::{future, prelude::*}; use opentelemetry::trace::TracerProvider as _; -use std::marker::PhantomData; use std::{ io, sync::{ @@ -20,7 +19,7 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; -use tarpc::context::{ExtractContext}; +use tarpc::context::DefaultContext; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -58,27 +57,25 @@ pub mod double { struct AddServer; impl AddService for AddServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, - ghost: PhantomData, +struct DoubleServer { + add_client: add::AddClient, } -impl DoubleService for DoubleServer +impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, - ClientCtx: From + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, { - type Context = context::Context; + type Context = context::DefaultContext; async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(&mut ClientCtx::from(context::current()), x, x) + .add(&mut context::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -129,16 +126,18 @@ where Ok((listener, addr)) } -fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], +fn make_stub( + backends: [impl Transport>, Response> + + Send + + Sync + + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp, ClientCtx>>, + load_balance::RoundRobin, Resp, DefaultContext, DefaultContext>>, > where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, - ClientCtx: ExtractContext + From + Send + Sync + 'static, { let stub = load_balance::RoundRobin::new( backends @@ -193,7 +192,7 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer::<_, context::Context> { add_client, ghost: PhantomData }.serve(); + let server = DoubleServer { add_client }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; @@ -201,7 +200,10 @@ async fn main() -> anyhow::Result<()> { double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(&mut context::current(), 1).await?); + tracing::info!( + "{:?}", + double_client.double(&mut context::current(), 1).await? + ); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 74031d969..ebd5db63b 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,6 +9,7 @@ mod in_flight_requests; pub mod stub; +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -19,20 +20,12 @@ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; use std::marker::PhantomData; -use std::{ - any::Any, - convert::TryFrom, - fmt, - pin::Pin, - sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }, - time::SystemTime, -}; +use std::{any::Any, convert::TryFrom, fmt, pin::Pin, sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}, time::SystemTime}; use tokio::sync::{mpsc, oneshot}; use tracing::Span; -use crate::context::ExtractContext; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -97,18 +90,18 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { - to_dispatch: mpsc::Sender>, +pub struct Channel { + to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, ///TODO: Document - ghost: PhantomData, + ghost: PhantomData, } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), @@ -119,10 +112,11 @@ impl Clone for Channel { } } -impl Channel +impl Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext + Clone, + SharedCtx: context::SharedContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -131,20 +125,26 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline().time_until()), otel.kind = "client", otel.name = %request.name()) )] pub async fn call(&self, ctx: &mut ClientCtx, request: Req) -> Result { let span = Span::current(); let mut shared_context = ctx.extract(); - shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + shared_context.set_trace_context(trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - shared_context.trace_context.new_child() - }); - span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id()), ); + shared_context.trace_context().new_child() + })); + span.record( + "rpc.trace_id", + tracing::field::display(shared_context.trace_context().trace_id), + ); + + ctx.update(shared_context); + let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -159,9 +159,10 @@ where cancellation: &self.cancellation, cancel: true, }; + self.to_dispatch .send(DispatchRequest { - ctx: shared_context, + ctx: ctx.clone(), // TODO: It would be best to forward the &mut ctx here, but unsure how to make the lifetimes work. span, request_id, request, @@ -172,7 +173,7 @@ where let (response_ctx, r) = response_guard.response().await?; - ctx.update(response_ctx); + ctx.update(response_ctx.extract()); Ok(r) } @@ -180,8 +181,8 @@ where /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. -struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, +struct ResponseGuard<'a, Resp, SharedCtx> { + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -208,8 +209,8 @@ pub enum RpcError { Server(#[from] ServerError), } -impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result<(context::Context, Resp), RpcError> { +impl ResponseGuard<'_, Resp, SharedCtx> { + async fn response(mut self) -> Result<(SharedCtx, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -226,7 +227,7 @@ impl ResponseGuard<'_, Resp> { } // Cancels the request when dropped, if not already complete. -impl Drop for ResponseGuard<'_, Resp> { +impl Drop for ResponseGuard<'_, Resp, SharedCtx> { fn drop(&mut self) { // The receiver needs to be closed to handle the edge case that the request has not // yet been received by the dispatch task. It is possible for the cancel message to @@ -247,10 +248,13 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient< + Channel, + RequestDispatch, +> where C: Transport, Response>, { @@ -271,7 +275,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, - ghost: PhantomData, + ghost: PhantomData }, } } @@ -281,16 +285,16 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, /// Requests waiting to be written to the wire. - pending_requests: mpsc::Receiver>, + pending_requests: mpsc::Receiver>, /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -299,17 +303,18 @@ pub struct RequestDispatch { /// determined within the poll function. terminal_error: Option>, - ghost: PhantomData, + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + Clone, //TODO: We need to clone the ClientCtx to be able to send it through the dispatch channel which requires a 'static type, thus we need to move it. I feel this limitation can be lifted with some smart lifetime magic. + SharedCtx: SharedContext, { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -326,7 +331,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send( + self: &mut Pin<&mut Self>, + message: ClientMessage, + ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -354,7 +362,7 @@ where fn pending_requests_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_requests } @@ -435,7 +443,7 @@ where fn poll_next_request( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, ChannelError>>> { + ) -> Poll, ChannelError>>> { if self.in_flight_requests().len() >= self.config.max_in_flight_requests { tracing::debug!( "At in-flight request capacity ({}/{}).", @@ -529,17 +537,25 @@ where // Therefore, we can call write_request without fear of erroring due to a full // buffer. - let trace_context = ctx.trace_context; - let deadline = ctx.deadline; + let shared_context = ctx.extract(); + + let trace_context = shared_context.trace_context(); + let deadline = shared_context.deadline(); let request = ClientMessage::Request(Request { id: request_id, message: request, - context: ctx.into(), //TODO: <-- This will actually initialize an empty client context, and the transport will never see the original + context: ctx, }); self.in_flight_requests() - .insert_request(request_id, trace_context, deadline, span.clone(), response_completion) + .insert_request( + request_id, + trace_context, + deadline, + span.clone(), + response_completion, + ) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -561,10 +577,11 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (trace_context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = + match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { @@ -584,7 +601,7 @@ where response .message .map_err(RpcError::Server) - .map(|m| (response.context.extract(), m)), + .map(|m| (response.context, m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -659,10 +676,12 @@ where } } -impl Future for RequestDispatch +impl Future + for RequestDispatch where C: Transport, Response>, - ClientCtx: ExtractContext + From, + ClientCtx: ExtractContext + Clone, + SharedCtx: context::SharedContext, { type Output = Result<(), ChannelError>; @@ -692,13 +711,12 @@ where /// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage /// the lifecycle of the request. #[derive(Debug)] -struct DispatchRequest { - ///TODO: this should be a &mut ClientCtx - pub ctx: context::Context, +struct DispatchRequest { + pub ctx: ClientCtx, ///TODO: this should be a &mut ClientCtx pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -706,11 +724,12 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; + use crate::context::DefaultContext; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, context, - transport::{self, channel::UnboundedChannel} + transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; use futures::{prelude::*, task::*}; @@ -741,7 +760,13 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context.trace_context, context.deadline, Span::current(), tx) + .insert_request( + 0, + context.trace_context, + context.deadline, + Span::current(), + tx, + ) .unwrap(); server_channel .send(Response { @@ -759,7 +784,7 @@ mod tests { async fn dispatch_response_cancels_on_drop() { let (cancellation, mut canceled_requests) = cancellations(); let (_, mut response) = oneshot::channel(); - drop(ResponseGuard:: { + drop(ResponseGuard:: { response: &mut response, cancellation: &cancellation, request_id: 3, @@ -774,8 +799,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok((context::current(), "well done"))) - .unwrap(); + tx.send(Ok((context::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -793,11 +817,11 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let _resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; #[allow(unstable_name_collisions)] let req = dispatch.as_mut().poll_next_request(cx).ready(); @@ -815,7 +839,7 @@ mod tests { let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _ = send_request(&mut channel, "hi", tx, &mut rx).await; + let _ = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; drop(channel); assert!(dispatch.as_mut().poll(cx).is_ready()); @@ -834,11 +858,11 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let _ = send_request(&mut channel, "hi", tx, &mut rx).await; + let _ = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; // Drop the channel so polling returns none if no requests are currently ready. drop(channel); @@ -850,11 +874,11 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); - let req = send_request(&mut channel, "hi", tx, &mut rx).await; + let req = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert!(dispatch.as_mut().pump_write(cx).ready().is_some()); assert!(!dispatch.in_flight_requests.is_empty()); @@ -871,14 +895,14 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up::(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); // Test that a request future that's closed its receiver but not yet canceled its request -- // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request // map. - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; resp.response.close(); assert!(dispatch.as_mut().poll_next_request(cx).is_pending()); @@ -887,14 +911,15 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; // Since there's an outstanding permit, dispatch should not complete yet. assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Pending); - let resp = permit("hi"); + let resp = permit(context::current(), "hi"); // errors from both the dispatch future and the request assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(e))) if matches!(*e, TransportError::Flush)); @@ -904,20 +929,23 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up::(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel.call(&mut context::current(), "hi".to_string()).await; + let resp = channel + .call(&mut context::current(), "hi".to_string()) + .await; assert_matches!(resp, Err(RpcError::Shutdown)); } #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert!(dispatch.as_mut().poll(&mut cx).is_pending()); let res = resp.response().await; assert_matches!(res, Err(RpcError::Send(_))); @@ -937,9 +965,10 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, mut channel, mut cx) = + set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); - let resp = send_request(&mut channel, "hi", tx, &mut rx).await; + let resp = send_request(&mut channel, context::current(), "hi", tx, &mut rx).await; assert_eq!( dispatch.as_mut().pump_write(&mut cx), Poll::Ready(Some(Ok(()))) @@ -954,7 +983,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -964,7 +993,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err::(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -974,7 +1003,8 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); + let (mut dispatch, channel, mut cx) = + set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -983,17 +1013,28 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin< + Box< + RequestDispatch< + String, + String, + ClientCtx, + SharedCtx, + AlwaysErrorTransport, + >, + >, + >, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = + AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, @@ -1079,11 +1120,12 @@ mod tests { String, String, ClientCtx, + DefaultContext, UnboundedChannel, ClientMessage>, >, >, >, - Channel, + Channel, UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1092,7 +1134,7 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, @@ -1113,16 +1155,16 @@ mod tests { } async fn reserve_for_send<'a, ClientCtx>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> impl FnOnce(ClientCtx, &str) -> ResponseGuard<'a, String, DefaultContext> { let permit = channel.to_dispatch.reserve().await.unwrap(); - |request| { + |ctx, request| { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx, span: Span::current(), request_id, request: request.to_string(), @@ -1139,15 +1181,16 @@ mod tests { } async fn send_request<'a, ClientCtx>( - channel: &'a mut Channel, + channel: &'a mut Channel, + context: ClientCtx, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> ResponseGuard<'a, String> { + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> ResponseGuard<'a, String, ClientCtx> { let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: context, span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index d6424c564..ec71ad628 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,26 +1,26 @@ +use crate::client::RpcError; use crate::{ - context, trace, - util::{Compact, TimeUntil} + trace, + util::{Compact, TimeUntil}, }; use fnv::FnvHashMap; +use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, }; -use std::time::Instant; use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; -use crate::client::RpcError; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] -pub struct InFlightRequests { - request_data: FnvHashMap>, +pub struct InFlightRequests { + request_data: FnvHashMap>, deadlines: DelayQueue, } -impl Default for InFlightRequests { +impl Default for InFlightRequests { fn default() -> Self { Self { request_data: Default::default(), @@ -30,10 +30,10 @@ impl Default for InFlightRequests { } #[derive(Debug)] -struct RequestData { +struct RequestData { ctx: trace::Context, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -43,7 +43,7 @@ struct RequestData { #[derive(Debug)] pub struct AlreadyExistsError; -impl InFlightRequests { +impl InFlightRequests { /// Returns the number of in-flight requests. pub fn len(&self) -> usize { self.request_data.len() @@ -61,7 +61,7 @@ impl InFlightRequests { ctx: trace::Context, deadline: Instant, span: Span, - response_completion: oneshot::Sender>, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -80,7 +80,11 @@ impl InFlightRequests { } /// Removes a request without aborting. Returns true if the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Result<(context::Context, Res), RpcError>) -> Option { + pub fn complete_request( + &mut self, + request_id: u64, + result: Result<(SharedCtx, Res), RpcError>, + ) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -98,7 +102,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Result<(context::Context, Res), RpcError> + 'a, + mut result: impl FnMut() -> Result<(SharedCtx, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -124,7 +128,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Result<(context::Context, Res), RpcError>, + expired_error: impl Fn() -> Result<(SharedCtx, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 5e473566c..8f49b31b3 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -28,14 +28,18 @@ pub trait Stub { type ClientCtx; /// Calls a remote service. - async fn call(&self, ctx: &mut Self::ClientCtx, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut Self::ClientCtx, + request: Self::Req, + ) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext, + ClientCtx: ExtractContext + Clone, + SharedCtx: context::SharedContext, { type Req = Req; type Resp = Resp; @@ -53,7 +57,11 @@ where type Req = S::Req; type Resp = S::Resp; type ClientCtx = S::ServerCtx; - async fn call(&self, ctx: &mut Self::ClientCtx, req: Self::Req) -> Result { + async fn call( + &self, + ctx: &mut Self::ClientCtx, + req: Self::Req, + ) -> Result { self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index eb605ecf9..5c6cc9aca 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -5,9 +5,7 @@ pub use round_robin::RoundRobin; /// Provides a stub that load-balances with a simple round-robin strategy. mod round_robin { - use crate::{ - client::{RpcError, stub}, - }; + use crate::client::{RpcError, stub}; use cycle::AtomicCycle; impl stub::Stub for RoundRobin @@ -98,9 +96,7 @@ mod round_robin { /// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use /// the same stub. mod consistent_hash { - use crate::{ - client::{RpcError, stub} - }; + use crate::client::{RpcError, stub}; use std::{ collections::hash_map::RandomState, hash::{BuildHasher, Hash}, diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index d4a6611e0..cbd7b3b8f 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -23,7 +23,7 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// be different for each request in scope. #[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct DefaultContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -37,11 +37,101 @@ pub struct Context { pub trace_context: trace::Context, } -///TODO +/// A shared, on-the-wire request context. +/// +/// `SharedContext` defines the minimal interface required for contexts that are +/// transmitted between peers as part of an RPC call. +/// +/// Implementations of this trait represent the *wire-level* context: a +/// portable, serializable view of request-scoped metadata such as deadlines +/// and tracing information. Service implementations are free to define +/// richer context types as part of their contract, for example to implement: +/// +/// - ephemeral sessions +/// - authentication / authorization data +/// - cookie-like state +/// - other application-specific metadata +/// +/// while preserving a common core required by the RPC runtime. +pub trait SharedContext { + /// Returns the absolute deadline for the request. + /// + /// The deadline represents the latest instant at which the request + /// should still be processed. RPC runtimes and middleware may use this + /// value to enforce timeouts, cancel in-flight work, or reject requests + /// that have already expired. + fn deadline(&self) -> Instant; + + + /// Returns the distributed tracing context associated with the request. + /// + /// This context is propagated across RPC boundaries to enable + /// end-to-end request tracing and correlation. + //TODO: May want to remove this in the long run from the default context, may need https://github.com/rust-lang/rust/issues/144361 for that. + fn trace_context(&self) -> trace::Context; + + + /// Updates the distributed tracing context. + /// + /// This is typically used by middleware to attach spans, propagate + /// trace identifiers, or replace the tracing context after deserialization. + //TODO: May want to remove this in the long run from the default context, may need https://github.com/rust-lang/rust/issues/144361 for that. + fn set_trace_context(&mut self, trace_context: trace::Context); +} + +impl SharedContext for DefaultContext { + fn deadline(&self) -> Instant { + self.deadline + } + + fn trace_context(&self) -> trace::Context { + self.trace_context + } + + fn set_trace_context(&mut self, trace_context: trace::Context) { + self.trace_context = trace_context; + } +} + +/// Extracts or updates a wire-level shared context contained within a client or server context. +/// +/// `ExtractContext` defines a bidirectional mapping between an internal +/// context representation and a *shared* context type (`Ctx`) that is +/// suitable for serialization and transmission over the wire. +/// +/// The shared context typically represents the minimal, stable data +/// exchanged between the client and server, while the +/// implementing type may contain additional, local side only state or +/// a different internal structure. +/// +/// This trait is intentionally symmetric: +/// - [`extract`](Self::extract) converts from the internal representation +/// into the shared, serializable context. +/// - [`update`](Self::update) applies a shared context to the internal +/// representation, updating or reconstructing local state as needed. +/// +/// # Design notes +/// +/// Implementations are expected to be *lossy* or *lossless* depending on +/// the application’s needs. Any information not representable in `Ctx` +/// must be reconstructed, defaulted, or retained internally by the +/// implementation. +// TODO: Revisit this trait once try_as_dyn is stabilized, https://github.com/rust-lang/rust/issues/29661. pub trait ExtractContext { - ///TODO + /// Extracts the inner context from the internal state. + /// + /// This method is typically called before sending the context over + /// the wire, or just after receiving it. The returned value should contain + /// all information required by the remote side to reconstruct or update its own + /// local context. fn extract(&self) -> Ctx; - ///TODO + + /// Updates the internal state from an inner context value. + /// + /// This method is typically called after executing a request and before + /// sending the updated context over the wire. Implementations should apply + /// the provided value to their internal representation, updating any derived or + /// local-only state as necessary. fn update(&mut self, value: Ctx); } @@ -111,15 +201,15 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(DefaultContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } /// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() +pub fn current() -> DefaultContext { + DefaultContext::current() } #[derive(Clone)] @@ -131,7 +221,7 @@ impl Default for Deadline { } } -impl Context { +impl DefaultContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -157,21 +247,21 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &T); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &T) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( - opentelemetry::trace::TraceId::from(context.trace_context.trace_id), - opentelemetry::trace::SpanId::from(context.trace_context.span_id), - opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision), + opentelemetry::trace::TraceId::from(context.trace_context().trace_id), + opentelemetry::trace::SpanId::from(context.trace_context().span_id), + opentelemetry::trace::TraceFlags::from(context.trace_context().sampling_decision), true, opentelemetry::trace::TraceState::default(), )) - .with_value(Deadline(context.deadline)), + .with_value(Deadline(context.deadline())), ); } } diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 76d9a1815..cff903110 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,7 +124,7 @@ //! struct HelloServer; //! //! impl World for HelloServer { -//! type Context = context::Context; +//! type Context = context::DefaultContext; //! // Each defined rpc generates an async fn that serves the RPC //! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") @@ -158,7 +158,7 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! # type Context = context::Context; +//! # type Context = context::DefaultContext; //! # //! # // Each defined rpc generates an async fn that serves the RPC //! # async fn hello(self, _: &mut Self::Context, name: String) -> String { @@ -181,7 +181,7 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 7e08db475..f4553450a 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -6,10 +6,11 @@ //! Provides a server that concurrently handles many connections sending multiplexed requests. +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{self, SpanExt}, + context::SpanExt, trace, util::TimeUntil, }; @@ -27,7 +28,6 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; -use crate::context::ExtractContext; mod in_flight_requests; pub mod request_hook; @@ -59,10 +59,14 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel( + self, + transport: T, + ) -> BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, + SharedCtx: SharedContext { BaseChannel::new(self, transport) } @@ -81,7 +85,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: &mut Self::ServerCtx, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut Self::ServerCtx, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. @@ -110,7 +118,10 @@ impl Copy for ServeFn where F: pub fn serve(f: F) -> ServeFn where // This should be -> impl Future<...>, but there is no syntax to express the 'a lifetime. - for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -147,7 +158,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -160,13 +171,14 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(Req, Resp, ServeCtx)>, + ghost: PhantomData<(Req, Resp, ServerCtx, SharedCtx)>, } -impl BaseChannel +impl BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -217,24 +229,24 @@ where let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %shared_context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_context().trace_id, + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline().time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); span.set_context(&shared_context); - shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + shared_context.set_trace_context(trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - shared_context.trace_context.new_child() - }); + shared_context.trace_context().new_child() + })); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - shared_context.deadline, + shared_context.deadline(), span.clone(), ); match start { @@ -259,7 +271,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -306,7 +318,10 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest::Req>>, + Self: Transport< + Response::Resp>, + TrackedRequest::Req>, + >, { /// Type of request item. type Req; @@ -434,10 +449,11 @@ where } } -impl Stream for BaseChannel +impl Stream for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { type Item = Result, ChannelError>; @@ -542,11 +558,13 @@ where } } -impl Sink> for BaseChannel +impl Sink> + for BaseChannel where T: Transport, ClientMessage>, T::Error: Error, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { type Error = ChannelError; @@ -557,7 +575,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -590,16 +611,17 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where T: Transport, ClientMessage>, - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, + SharedCtx: SharedContext, { type Req = Req; type Resp = Resp; @@ -849,7 +871,7 @@ impl InFlightRequest { /// /// 1. The channel that yielded this request receives a [cancellation /// message](ClientMessage::Cancel) for this request. - /// 2. The request [deadline](crate::context::Context::deadline) is reached. + /// 2. The request [deadline](crate::context::DefaultContext::deadline) is reached. /// 3. The service function completes. /// /// If the returned Future is dropped before completion, a cancellation message will be sent to @@ -983,7 +1005,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::{ExtractContext}; + use crate::context::ExtractContext; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1003,16 +1025,50 @@ mod tests { }; fn test_channel() -> ( - Pin, Response>, context::Context>>>, - UnboundedChannel, ClientMessage>, + Pin< + Box< + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, + >, + >, + UnboundedChannel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) } fn test_requests() -> ( - Pin, Response>, context::Context>>>>, - UnboundedChannel, ClientMessage>, + Pin< + Box< + Requests< + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, + >, + >, + >, + UnboundedChannel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1024,8 +1080,26 @@ mod tests { fn test_bounded_requests( capacity: usize, ) -> ( - Pin, Response>, context::Context>>>>, - channel::Channel, ClientMessage>, + Pin< + Box< + Requests< + BaseChannel< + Req, + Resp, + channel::Channel< + ClientMessage, + Response, + >, + context::DefaultContext, + context::DefaultContext, + >, + >, + >, + >, + channel::Channel< + Response, + ClientMessage, + >, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1035,7 +1109,7 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { context: context::current(), id: 0, @@ -1060,13 +1134,9 @@ mod tests { struct SetDeadline(Instant); impl BeforeRequest for SetDeadline where - ServerCtx: ExtractContext, + ServerCtx: ExtractContext, { - async fn before( - &mut self, - ctx: &mut ServerCtx, - _: &Req - ) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { let mut inner = ctx.extract(); inner.deadline = self.0; ctx.update(inner); @@ -1077,10 +1147,13 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: &mut context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) - }.boxed()); + let serve = serve(move |ctx: &mut context::DefaultContext, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() + }); let deadline_hook = serve.before(SetDeadline(some_time)); let mut ctx = context::current(); ctx.deadline = some_other_time; @@ -1103,11 +1176,7 @@ mod tests { } } impl BeforeRequest for PrintLatency { - async fn before( - &mut self, - _: &mut ServerCtx, - _: &Req - ) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } @@ -1118,7 +1187,7 @@ mod tests { } } - let serve = serve(move |_: &mut context::Context, i| async move { Ok(i) }.boxed()); + let serve = serve(move |_: &mut context::DefaultContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) .serve(&mut context::current(), 7) @@ -1129,7 +1198,7 @@ mod tests { #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let deadline_hook = serve.before(|_: &mut context::DefaultContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); let resp: Result = deadline_hook.serve(&mut context::current(), 7).await; @@ -1338,7 +1407,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index 527cb6f98..3fa81e580 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -180,6 +180,7 @@ where mod tests { use super::*; + use crate::context; use crate::server::{ TrackedRequest, testing::{self, FakeChannel, PollExt}, @@ -190,7 +191,6 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; - use crate::context; #[tokio::test] async fn throttler_in_flight_requests() { @@ -270,8 +270,10 @@ mod tests { ghost: PhantomData In>, } impl PendingSink<(), ()> { - pub fn default() - -> PendingSink>, Response, > { + pub fn default() -> PendingSink< + io::Result>, + Response, + > { PendingSink { ghost: PhantomData } } } @@ -297,11 +299,15 @@ mod tests { } } impl Channel - for PendingSink>, Response> { + for PendingSink< + io::Result>, + Response, + > + { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = context::Context; + type ServerCtx = context::DefaultContext; fn config(&self) -> &Config { unimplemented!() } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index cce5998ee..afc9ad25a 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -48,7 +48,7 @@ pub trait RequestHook: Serve { /// use std::io; /// /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// .before(|_ctx: &mut context::DefaultContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -95,7 +95,7 @@ pub trait RequestHook: Serve { /// Ok(i + 1) /// } /// }.boxed()) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// .after(|_ctx: &mut context::DefaultContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index 13fc18509..adfac8e79 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -80,7 +80,11 @@ impl Clone for HookThenServe HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook, ghost: PhantomData } + Self { + serve, + hook, + ghost: PhantomData, + } } } @@ -93,11 +97,7 @@ where type Req = Serv::Req; type Resp = Serv::Resp; - async fn serve( - self, - ctx: &mut ServerCtx, - req: Self::Req - ) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Self::Req) -> Result { let HookThenServe { serve, mut hook, .. } = self; diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 934d82ad5..f3653a513 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -39,7 +39,8 @@ impl Clone } } -impl Serve for HookThenServeThenHook +impl Serve + for HookThenServeThenHook where Req: RequestName, Serv: Serve, diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 39eabdaf5..bbb396c45 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -38,7 +38,8 @@ where } } -impl Sink> for FakeChannel> +impl Sink> + for FakeChannel> { type Error = io::Error; @@ -46,7 +47,10 @@ impl Sink> for FakeChannel, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -66,14 +70,18 @@ impl Sink> for FakeChannel Channel for FakeChannel>, Response> +impl Channel + for FakeChannel< + io::Result>, + Response, + > where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); - type ServerCtx = context::Context; + type ServerCtx = context::DefaultContext; fn config(&self) -> &Config { &self.config @@ -88,14 +96,18 @@ where } } -impl FakeChannel>, Response> +impl + FakeChannel< + io::Result>, + Response, + > { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::DefaultContext { deadline: Instant::now(), trace_context: Default::default(), }, @@ -114,8 +126,10 @@ impl FakeChannel>, R } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> - { + pub fn default() -> FakeChannel< + io::Result>, + Response, + > { let (request_cancellation, canceled_requests) = cancellations(); FakeChannel { stream: Default::default(), diff --git a/tarpc/src/trace.rs b/tarpc/src/trace.rs index d172dc2e2..6b00e7f71 100644 --- a/tarpc/src/trace.rs +++ b/tarpc/src/trace.rs @@ -66,13 +66,14 @@ pub struct SpanId(u64); /// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then /// the downstream samplers are expected to respect that decision and also sample the trace. /// Otherwise, the full trace would not be able to be reconstructed reliably. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Default)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[repr(u8)] pub enum SamplingDecision { /// The associated span was sampled by its creating process. Child spans must also be sampled. Sampled, /// The associated span was not sampled by its creating process. + #[default] Unsampled, } @@ -203,12 +204,6 @@ impl From<&opentelemetry::trace::SpanContext> for SamplingDecision { } } -impl Default for SamplingDecision { - fn default() -> Self { - Self::Unsampled - } -} - /// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span). #[derive(Debug)] pub struct NoActiveSpan; diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 35c81fb1e..658925e6d 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -191,14 +191,19 @@ mod tests { tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx: &mut context::Context, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) - }.boxed())) + .execute(serve( + |_ctx: &mut context::DefaultContext, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() + }, + )) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 812fc4ee7..9498d02b2 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; -use tarpc::context::Context; +use tarpc::context::DefaultContext; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index 4fe34df5f..93e6aef45 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr index bdb46999c..630924f71 100644 --- a/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr +++ b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr @@ -2,10 +2,15 @@ error[E0277]: the trait bound `FooRequest: serde::Serialize` is not satisfied --> tests/compile_fail/serde1/opt_out_serde.rs:12:40 | 12 | tarpc::serde::Serialize::serialize(&x, f); - | ---------------------------------- ^^ the trait `Serialize` is not implemented for `FooRequest` + | ---------------------------------- ^^ unsatisfied trait bound | | | required by a bound introduced by this call | +help: the trait `Serialize` is not implemented for `FooRequest` + --> tests/compile_fail/serde1/opt_out_serde.rs:5:1 + | + 5 | #[tarpc::service(derive_serde = false)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `FooRequest` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -18,3 +23,4 @@ error[E0277]: the trait bound `FooRequest: serde::Serialize` is not satisfied (T0, T1, T2, T3) (T0, T1, T2, T3, T4) and $N others + = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 5a5b2f8e7..ac4d17027 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 559521414..28157b25f 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -8,7 +8,6 @@ use tarpc::{ client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, - transport, transport::channel, }; use tokio::join; @@ -23,7 +22,7 @@ trait Service { struct Server; impl Service for Server { - type Context = context::Context; + type Context = context::DefaultContext; async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } @@ -40,7 +39,9 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); assert_eq!(client.call(&mut context::current(), 1).await.unwrap(), 2); @@ -57,7 +58,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - type Context = context::Context; + type Context = context::DefaultContext; async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); @@ -260,7 +261,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - type Context = context::Context; + type Context = context::DefaultContext; async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 From 067e3aeb616af3a320df7c5588a4a81d15eab9fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Fri, 26 Dec 2025 13:03:51 +0100 Subject: [PATCH 25/26] run cargo fmt and clippy. Rename CustomContext to CustomSharedContext --- tarpc/examples/custom_context.rs | 435 ++++++++++++++++++------------- tarpc/src/client.rs | 20 +- tarpc/src/context.rs | 8 +- tarpc/src/server.rs | 10 +- 4 files changed, 284 insertions(+), 189 deletions(-) diff --git a/tarpc/examples/custom_context.rs b/tarpc/examples/custom_context.rs index 64978b063..19d490de5 100644 --- a/tarpc/examples/custom_context.rs +++ b/tarpc/examples/custom_context.rs @@ -5,90 +5,94 @@ // https://opensource.org/licenses/MIT. #![deny(warnings, unused, dead_code)] +use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; -use futures::{SinkExt, TryStreamExt, StreamExt, FutureExt}; -use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex; -use tarpc::{client, server::{self, Channel}, trace, ClientMessage, Request, Response, ServerError, Transport}; use tarpc::context::{ExtractContext, SharedContext}; use tarpc::server::request_hook::{AfterRequest, BeforeRequest, RequestHook}; use tarpc::transport::channel::UnboundedChannel; +use tarpc::{ + ClientMessage, Request, Response, ServerError, Transport, client, + server::{self, Channel}, + trace, +}; +use tokio::sync::Mutex; #[derive(Serialize, Deserialize, Clone)] -struct CustomContext { - #[serde(with = "absolute_to_relative_time")] - pub deadline: Instant, - pub trace_context: trace::Context, - pub session_id: Option +struct CustomSharedContext { + #[serde(with = "absolute_to_relative_time")] + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, } -impl SharedContext for CustomContext { - fn deadline(&self) -> Instant { - self.deadline - } +impl SharedContext for CustomSharedContext { + fn deadline(&self) -> Instant { + self.deadline + } - fn trace_context(&self) -> trace::Context { - self.trace_context - } + fn trace_context(&self) -> trace::Context { + self.trace_context + } - fn set_trace_context(&mut self, trace_context: trace::Context) { - self.trace_context = trace_context; - } + fn set_trace_context(&mut self, trace_context: trace::Context) { + self.trace_context = trace_context; + } } #[derive(Clone, Debug)] struct ClientContext { - pub session_id: Option, - pub delay_sending_by_seconds: u32, + pub session_id: Option, + pub delay_sending_by_seconds: u32, } struct ServerContext { - pub deadline: Instant, - pub trace_context: trace::Context, - pub session_id: Option, - pub balance: u64, + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, + pub balance: u64, } -impl ExtractContext for ClientContext { - fn extract(&self) -> CustomContext { - CustomContext { - deadline: Instant::now().add(Duration::from_secs(60)), - trace_context: Default::default(), - session_id: self.session_id, +impl ExtractContext for ClientContext { + fn extract(&self) -> CustomSharedContext { + CustomSharedContext { + deadline: Instant::now().add(Duration::from_secs(60)), + trace_context: Default::default(), + session_id: self.session_id, + } } - } - fn update(&mut self, value: CustomContext) { - self.session_id = value.session_id; - } + fn update(&mut self, value: CustomSharedContext) { + self.session_id = value.session_id; + } } -impl ExtractContext for ServerContext { - fn extract(&self) -> CustomContext { - CustomContext { - deadline: self.deadline, - trace_context: self.trace_context, - session_id: self.session_id, +impl ExtractContext for ServerContext { + fn extract(&self) -> CustomSharedContext { + CustomSharedContext { + deadline: self.deadline, + trace_context: self.trace_context, + session_id: self.session_id, + } } - } - fn update(&mut self, value: CustomContext) { - self.deadline = value.deadline; - self.trace_context = value.trace_context; - self.session_id = value.session_id; - } + fn update(&mut self, value: CustomSharedContext) { + self.deadline = value.deadline; + self.trace_context = value.trace_context; + self.session_id = value.session_id; + } } /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. #[tarpc::service(shared_context = "CustomContext")] pub trait World { - async fn create_session() -> (); - async fn increase_balance(credits: u32) -> (); - async fn hello(name: String) -> Result; + async fn create_session() -> (); + async fn increase_balance(credits: u32) -> (); + async fn hello(name: String) -> Result; } /// This is the type that implements the generated World trait. It is the business logic @@ -97,166 +101,245 @@ pub trait World { struct HelloServer; impl World for HelloServer { - type Context = ServerContext; + type Context = ServerContext; - async fn create_session(self, ctx: &mut Self::Context) -> () { - ctx.session_id = Some(42); - ctx.balance = 0; - } - - async fn increase_balance(self, ctx: &mut Self::Context, credits: u32) -> () { - ctx.balance = ctx.balance + credits as u64; - } - - async fn hello(self, ctx: &mut Self::Context, name: String) -> Result { - if ctx.session_id != Some(42) { - Err("Session not yet initialized!")? + async fn create_session(self, ctx: &mut Self::Context) -> () { + ctx.session_id = Some(42); + ctx.balance = 0; } - if ctx.balance == 0 { - Err("Give me more money")? + async fn increase_balance(self, ctx: &mut Self::Context, credits: u32) -> () { + ctx.balance = ctx.balance + credits as u64; } - ctx.balance = ctx.balance - 1; + async fn hello(self, ctx: &mut Self::Context, name: String) -> Result { + if ctx.session_id != Some(42) { + Err("Session not yet initialized!")? + } + + if ctx.balance == 0 { + Err("Give me more money")? + } - Ok(format!("Hello, {name}!")) - } + ctx.balance = ctx.balance - 1; + + Ok(format!("Hello, {name}!")) + } } async fn spawn(fut: impl Future + Send + 'static) { - tokio::spawn(fut); + tokio::spawn(fut); } - - - #[derive(Clone)] struct SessionHook { - balances: Arc>> + balances: Arc>>, } impl BeforeRequest for SessionHook { - async fn before(&mut self, ctx: &mut ServerContext, _req: &WorldRequest) -> Result<(), ServerError> { - if let Some(id) = ctx.session_id { - let balances = self.balances.lock().await; - ctx.balance = *balances.get(&id).unwrap_or(&0u64) + async fn before( + &mut self, + ctx: &mut ServerContext, + _req: &WorldRequest, + ) -> Result<(), ServerError> { + if let Some(id) = ctx.session_id { + let balances = self.balances.lock().await; + ctx.balance = *balances.get(&id).unwrap_or(&0u64) + } + + Ok(()) } - - Ok(()) - } } impl AfterRequest for SessionHook { - async fn after(&mut self, ctx: &mut ServerContext, resp: &mut Result) { - if resp.is_ok() && let Some(id) = ctx.session_id { - let mut balances = self.balances.lock().await; - - let b = balances.entry(id).or_insert(0); - *b = ctx.balance; + async fn after( + &mut self, + ctx: &mut ServerContext, + resp: &mut Result, + ) { + if resp.is_ok() + && let Some(id) = ctx.session_id + { + let mut balances = self.balances.lock().await; + + let b = balances.entry(id).or_insert(0); + *b = ctx.balance; + } } - } } - - - #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = create_channel(); - - let client_transport = client_transport.with(|f: ClientMessage| async move { + let (client_transport, server_transport) = create_channel(); + + let client_transport = + client_transport.with(|f: ClientMessage| { + async move { + if let ClientMessage::Request(Request { + ref context, + ref message, + .. + }) = f + { + println!("msg = {:?}, ctx = {:?}", message, context); + tokio::time::sleep(Duration::from_secs( + context.delay_sending_by_seconds as u64, + )) + .await; + } + + Ok(f) + } + .boxed() + }); + + let server = server::BaseChannel::with_defaults(server_transport); + let hook = SessionHook { + balances: Arc::new(Mutex::new(HashMap::new())), + }; + tokio::spawn( + server + .execute(HelloServer.serve().before_and_after(hook)) + .for_each(spawn), + ); + + // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` + // that takes a config and any Transport as input. + let client = WorldClient::new(client::Config::default(), client_transport).spawn(); + + // The client has an RPC method for each RPC defined in the annotated trait. It takes the same + // args as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + + let mut client_context: ClientContext = ClientContext { + session_id: None, + delay_sending_by_seconds: 1, + }; + + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + + assert_eq!(hello, Err("Session not yet initialized!".to_string())); + + let _ = client.create_session(&mut client_context).await?; + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + assert_eq!(hello, Err("Give me more money".to_string())); + + let _ = client.increase_balance(&mut client_context, 2u32).await?; + + let hello = client + .hello(&mut client_context, "Stan".to_string()) + .await?; + assert_eq!(hello, Ok("Hello, Stan!".to_string())); + + let hello = client + .hello(&mut client_context, "Frank".to_string()) + .await?; + assert_eq!(hello, Ok("Hello, Frank!".to_string())); + + let hello = client + .hello(&mut client_context, "Joshua".to_string()) + .await?; + assert_eq!(hello, Err("Give me more money".to_string())); - if let ClientMessage::Request(Request {ref context, ref message, .. }) = f { - println!("msg = {:?}, ctx = {:?}", message, context); - tokio::time::sleep(Duration::from_secs(context.delay_sending_by_seconds as u64)).await; - } - - Ok(f) - }.boxed()); - - let server = server::BaseChannel::with_defaults(server_transport); - let hook = SessionHook { balances: Arc::new(Mutex::new(HashMap::new())) }; - tokio::spawn(server.execute(HelloServer.serve().before_and_after(hook)).for_each(spawn)); - - // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` - // that takes a config and any Transport as input. - let client = WorldClient::new(client::Config::default(), client_transport).spawn(); - - // The client has an RPC method for each RPC defined in the annotated trait. It takes the same - // args as defined, with the addition of a Context, which is always the first arg. The Context - // specifies a deadline and trace information which can be helpful in debugging requests. - - let mut client_context: ClientContext = ClientContext { - session_id: None, - delay_sending_by_seconds: 1, - }; - - let hello = client.hello(&mut client_context, "Stan".to_string()).await?; - - assert_eq!(hello, Err("Session not yet initialized!".to_string())); - - let _ = client.create_session(&mut client_context).await?; - let hello = client.hello(&mut client_context, "Stan".to_string()).await?; - assert_eq!(hello, Err("Give me more money".to_string())); - - let _ = client.increase_balance(&mut client_context, 2u32).await?; - - let hello = client.hello(&mut client_context, "Stan".to_string()).await?; - assert_eq!(hello, Ok("Hello, Stan!".to_string())); - - let hello = client.hello(&mut client_context, "Frank".to_string()).await?; - assert_eq!(hello, Ok("Hello, Frank!".to_string())); - - let hello = client.hello(&mut client_context, "Joshua".to_string()).await?; - assert_eq!(hello, Err("Give me more money".to_string())); - - Ok(()) + Ok(()) } //*** Helper functions below ***// mod absolute_to_relative_time { - pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; - pub use std::time::{Duration, Instant}; - - pub fn serialize(deadline: &Instant, serializer: S) -> Result - where - S: Serializer, - { - let deadline = deadline.duration_since(Instant::now()); - deadline.serialize(serializer) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let deadline = Duration::deserialize(deserializer)?; - Ok(Instant::now() + deadline) - } -} + pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; + pub use std::time::{Duration, Instant}; + + pub fn serialize(deadline: &Instant, serializer: S) -> Result + where + S: Serializer, + { + let deadline = deadline.duration_since(Instant::now()); + deadline.serialize(serializer) + } -fn map_request_context(req: ClientMessage, f: impl FnOnce(Ctx) -> Ctx2) -> ClientMessage { - match req { - ClientMessage::Request(Request { context, id, message}) => ClientMessage::Request(Request { context: f(context), id, message }), - ClientMessage::Cancel { trace_context, request_id } => ClientMessage::Cancel { trace_context, request_id }, - _ => unimplemented!() - } + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let deadline = Duration::deserialize(deserializer)?; + Ok(Instant::now() + deadline) + } } -fn map_response_context(res: Response, f: impl FnOnce(Ctx) -> Ctx2) -> Response { - Response { - request_id: res.request_id, - context: f(res.context), - message: res.message - } +fn map_request_context( + req: ClientMessage, + f: impl FnOnce(Ctx) -> Ctx2, +) -> ClientMessage { + match req { + ClientMessage::Request(Request { + context, + id, + message, + }) => ClientMessage::Request(Request { + context: f(context), + id, + message, + }), + ClientMessage::Cancel { + trace_context, + request_id, + } => ClientMessage::Cancel { + trace_context, + request_id, + }, + _ => unimplemented!(), + } } -fn create_channel() -> (impl Transport, Response>, impl Transport, ClientMessage>) { - let (client, server): (UnboundedChannel, ClientMessage>, UnboundedChannel, Response>) = tarpc::transport::channel::unbounded(); - - let client = client.with(|m| futures::future::ok(map_request_context(m, |c: ClientContext| c.extract()))).map_ok(|r| map_response_context(r, |c: CustomContext| ClientContext { session_id: c.session_id, delay_sending_by_seconds: 0 })); - let server = server.with(|r| futures::future::ok(map_response_context(r, |c: ServerContext| c.extract()))).map_ok(|m| map_request_context(m, |c| ServerContext { deadline: c.deadline, trace_context: c.trace_context, session_id: c.session_id, balance: 0 })); +fn map_response_context( + res: Response, + f: impl FnOnce(Ctx) -> Ctx2, +) -> Response { + Response { + request_id: res.request_id, + context: f(res.context), + message: res.message, + } +} - (client, server) -} \ No newline at end of file +fn create_channel() -> ( + impl Transport, Response>, + impl Transport, ClientMessage>, +) { + let (client, server): ( + UnboundedChannel< + Response, + ClientMessage, + >, + UnboundedChannel< + ClientMessage, + Response, + >, + ) = tarpc::transport::channel::unbounded(); + + let client = client + .with(|m| futures::future::ok(map_request_context(m, |c: ClientContext| c.extract()))) + .map_ok(|r| { + map_response_context(r, |c: CustomSharedContext| ClientContext { + session_id: c.session_id, + delay_sending_by_seconds: 0, + }) + }); + let server = server + .with(|r| futures::future::ok(map_response_context(r, |c: ServerContext| c.extract()))) + .map_ok(|m| { + map_request_context(m, |c| ServerContext { + deadline: c.deadline, + trace_context: c.trace_context, + session_id: c.session_id, + balance: 0, + }) + }); + + (client, server) +} diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ebd5db63b..ab7726da8 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -20,10 +20,17 @@ use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; use pin_project::pin_project; use std::marker::PhantomData; -use std::{any::Any, convert::TryFrom, fmt, pin::Pin, sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, -}, time::SystemTime}; +use std::{ + any::Any, + convert::TryFrom, + fmt, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::SystemTime, +}; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -275,7 +282,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, - ghost: PhantomData + ghost: PhantomData, }, } } @@ -712,7 +719,8 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: ClientCtx, ///TODO: this should be a &mut ClientCtx + pub ctx: ClientCtx, + ///TODO: this should be a &mut ClientCtx pub span: Span, pub request_id: u64, pub request: Req, diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index cbd7b3b8f..95b5a9f81 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -60,8 +60,7 @@ pub trait SharedContext { /// should still be processed. RPC runtimes and middleware may use this /// value to enforce timeouts, cancel in-flight work, or reject requests /// that have already expired. - fn deadline(&self) -> Instant; - + fn deadline(&self) -> Instant; /// Returns the distributed tracing context associated with the request. /// @@ -70,7 +69,6 @@ pub trait SharedContext { //TODO: May want to remove this in the long run from the default context, may need https://github.com/rust-lang/rust/issues/144361 for that. fn trace_context(&self) -> trace::Context; - /// Updates the distributed tracing context. /// /// This is typically used by middleware to attach spans, propagate @@ -257,7 +255,9 @@ impl SpanExt for tracing::Span { .with_remote_span_context(opentelemetry::trace::SpanContext::new( opentelemetry::trace::TraceId::from(context.trace_context().trace_id), opentelemetry::trace::SpanId::from(context.trace_context().span_id), - opentelemetry::trace::TraceFlags::from(context.trace_context().sampling_decision), + opentelemetry::trace::TraceFlags::from( + context.trace_context().sampling_decision, + ), true, opentelemetry::trace::TraceState::default(), )) diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index f4553450a..42c72d176 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -66,7 +66,7 @@ impl Config { where T: Transport, ClientMessage>, ServerCtx: ExtractContext, - SharedCtx: SharedContext + SharedCtx: SharedContext, { BaseChannel::new(self, transport) } @@ -271,7 +271,9 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug + for BaseChannel +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -611,7 +613,9 @@ where } } -impl AsRef for BaseChannel { +impl AsRef + for BaseChannel +{ fn as_ref(&self) -> &T { self.transport.get_ref() } From 8348cf6f116effe46cdb27e4437d3955bdc92f7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81kos=20Vandra-Meyer?= Date: Fri, 26 Dec 2025 13:27:32 +0100 Subject: [PATCH 26/26] separate context extraction and updating, since the server side only needs extraction. --- tarpc/examples/custom_context.rs | 42 +++++++++++++++++--------------- tarpc/src/client.rs | 6 ++--- tarpc/src/client/stub.rs | 4 +-- tarpc/src/context.rs | 38 ++++++++++++++++++++--------- tarpc/src/server.rs | 12 +++------ 5 files changed, 57 insertions(+), 45 deletions(-) diff --git a/tarpc/examples/custom_context.rs b/tarpc/examples/custom_context.rs index 19d490de5..6a76bcb12 100644 --- a/tarpc/examples/custom_context.rs +++ b/tarpc/examples/custom_context.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; -use tarpc::context::{ExtractContext, SharedContext}; +use tarpc::context::{ExtractContext, SharedContext, UpdateContext}; use tarpc::server::request_hook::{AfterRequest, BeforeRequest, RequestHook}; use tarpc::transport::channel::UnboundedChannel; use tarpc::{ @@ -21,6 +21,8 @@ use tarpc::{ }; use tokio::sync::Mutex; + +/// This is the context that is sent between the client and server. #[derive(Serialize, Deserialize, Clone)] struct CustomSharedContext { #[serde(with = "absolute_to_relative_time")] @@ -29,6 +31,22 @@ struct CustomSharedContext { pub session_id: Option, } +/// This context is only seen by the client side. It can be used by the transport, or any step before sending to dispatch, +/// In this case used by the very simple example of delaying the request, but it can be used for batching, prioritisation, etc. +#[derive(Clone, Debug)] +struct ClientContext { + pub session_id: Option, + pub delay_sending_by_seconds: u32, +} + +/// This context is only seen by the server. It can be used by the transport, the hooks or the service implementation itself to +/// influence its behaviour. In our case the SessionHook extracts the session data +struct ServerContext { + pub deadline: Instant, + pub trace_context: trace::Context, + pub session_id: Option, + pub balance: u64 +} impl SharedContext for CustomSharedContext { fn deadline(&self) -> Instant { self.deadline @@ -43,18 +61,6 @@ impl SharedContext for CustomSharedContext { } } -#[derive(Clone, Debug)] -struct ClientContext { - pub session_id: Option, - pub delay_sending_by_seconds: u32, -} - -struct ServerContext { - pub deadline: Instant, - pub trace_context: trace::Context, - pub session_id: Option, - pub balance: u64, -} impl ExtractContext for ClientContext { fn extract(&self) -> CustomSharedContext { @@ -64,7 +70,9 @@ impl ExtractContext for ClientContext { session_id: self.session_id, } } +} +impl UpdateContext for ClientContext { fn update(&mut self, value: CustomSharedContext) { self.session_id = value.session_id; } @@ -78,17 +86,11 @@ impl ExtractContext for ServerContext { session_id: self.session_id, } } - - fn update(&mut self, value: CustomSharedContext) { - self.deadline = value.deadline; - self.trace_context = value.trace_context; - self.session_id = value.session_id; - } } /// This is the service definition. It looks a lot like a trait definition. /// It defines one RPC, hello, which takes one arg, name, and returns a String. -#[tarpc::service(shared_context = "CustomContext")] +#[tarpc::service(shared_context = "CustomSharedContext")] pub trait World { async fn create_session() -> (); async fn increase_balance(credits: u32) -> (); diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index ab7726da8..7ebe88cb2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,7 +9,7 @@ mod in_flight_requests; pub mod stub; -use crate::context::{ExtractContext, SharedContext}; +use crate::context::{ExtractContext, SharedContext, UpdateContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -122,8 +122,8 @@ impl Clone for Channel Channel where Req: RequestName, - ClientCtx: ExtractContext + Clone, - SharedCtx: context::SharedContext, + ClientCtx: UpdateContext + Clone, + SharedCtx: SharedContext, { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 8f49b31b3..06e0e438d 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -4,7 +4,7 @@ use crate::{ RequestName, client::{Channel, RpcError}, context, - context::ExtractContext, + context::UpdateContext, server::Serve, }; @@ -38,7 +38,7 @@ pub trait Stub { impl Stub for Channel where Req: RequestName, - ClientCtx: ExtractContext + Clone, + ClientCtx: UpdateContext + Clone, SharedCtx: context::SharedContext, { type Req = Req; diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 95b5a9f81..5a510219b 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -91,9 +91,9 @@ impl SharedContext for DefaultContext { } } -/// Extracts or updates a wire-level shared context contained within a client or server context. +/// Extracts a wire-level shared context contained within a client or server context. /// -/// `ExtractContext` defines a bidirectional mapping between an internal +/// `ExtractContext` defines a mapping between an internal /// context representation and a *shared* context type (`Ctx`) that is /// suitable for serialization and transmission over the wire. /// @@ -102,18 +102,10 @@ impl SharedContext for DefaultContext { /// implementing type may contain additional, local side only state or /// a different internal structure. /// -/// This trait is intentionally symmetric: -/// - [`extract`](Self::extract) converts from the internal representation -/// into the shared, serializable context. -/// - [`update`](Self::update) applies a shared context to the internal -/// representation, updating or reconstructing local state as needed. -/// /// # Design notes /// -/// Implementations are expected to be *lossy* or *lossless* depending on -/// the application’s needs. Any information not representable in `Ctx` -/// must be reconstructed, defaulted, or retained internally by the -/// implementation. +/// If a type implements both UpdateContext and ExtractContext, it is expected that +/// `foo.update(v).extract() == v` will hold. // TODO: Revisit this trait once try_as_dyn is stabilized, https://github.com/rust-lang/rust/issues/29661. pub trait ExtractContext { /// Extracts the inner context from the internal state. @@ -124,6 +116,24 @@ pub trait ExtractContext { /// local context. fn extract(&self) -> Ctx; +} + +/// Updates a wire-level shared context contained within a client context. +/// +/// `ExtractContext` defines a mapping between a *shared* context type (`Ctx`) +/// and an internal context representation +/// +/// The shared context typically represents the minimal, stable data +/// exchanged between the client and server, while the +/// implementing type may contain additional, local side only state or +/// a different internal structure. +/// +/// # Design notes +/// +/// It is expected that `ctx.update(shared_ctx).extract() == shared_ctx` will always hold. +/// +// TODO: Revisit this trait once try_as_dyn is stabilized, https://github.com/rust-lang/rust/issues/29661. +pub trait UpdateContext: ExtractContext { /// Updates the internal state from an inner context value. /// /// This method is typically called after executing a request and before @@ -140,12 +150,16 @@ where fn extract(&self) -> T { self.clone() } +} +impl> UpdateContext for T { fn update(&mut self, value: T) { *self = value } } + + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index 42c72d176..47c8476c4 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -1009,7 +1009,7 @@ mod tests { request_hook::{AfterRequest, BeforeRequest, RequestHook}, serve, }; - use crate::context::ExtractContext; + use crate::context::{DefaultContext}; use crate::{ ClientMessage, Request, Response, ServerError, context, trace, transport::channel::{self, UnboundedChannel}, @@ -1136,14 +1136,10 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline - where - ServerCtx: ExtractContext, + impl BeforeRequest for SetDeadline { - async fn before(&mut self, ctx: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { - let mut inner = ctx.extract(); - inner.deadline = self.0; - ctx.update(inner); + async fn before(&mut self, ctx: &mut DefaultContext, _: &Req) -> Result<(), ServerError> { + ctx.deadline = self.0; Ok(()) } }